chiark / gitweb /
util: fix fd_cloexec(), fd_nonblock()
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "sd-resolve.h"
37 #include "log.h"
38 #include "socket-util.h"
39 #include "util.h"
40 #include "event-util.h"
41 #include "build.h"
42 #include "set.h"
43 #include "path-util.h"
44
45 #define BUFFER_SIZE (256 * 1024)
46 #define CONNECTIONS_MAX 256
47
48 static const char *arg_remote_host = NULL;
49
50 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
51 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
52
53 typedef struct Context {
54         sd_event *event;
55         sd_resolve *resolve;
56
57         Set *listen;
58         Set *connections;
59 } Context;
60
61 typedef struct Connection {
62         Context *context;
63
64         int server_fd, client_fd;
65         int server_to_client_buffer[2]; /* a pipe */
66         int client_to_server_buffer[2]; /* a pipe */
67
68         size_t server_to_client_buffer_full, client_to_server_buffer_full;
69         size_t server_to_client_buffer_size, client_to_server_buffer_size;
70
71         sd_event_source *server_event_source, *client_event_source;
72
73         sd_resolve_query *resolve_query;
74 } Connection;
75
76 static void connection_free(Connection *c) {
77         assert(c);
78
79         if (c->context)
80                 set_remove(c->context->connections, c);
81
82         sd_event_source_unref(c->server_event_source);
83         sd_event_source_unref(c->client_event_source);
84
85         safe_close(c->server_fd);
86         safe_close(c->client_fd);
87
88         safe_close_pair(c->server_to_client_buffer);
89         safe_close_pair(c->client_to_server_buffer);
90
91         sd_resolve_query_unref(c->resolve_query);
92
93         free(c);
94 }
95
96 static void context_free(Context *context) {
97         sd_event_source *es;
98         Connection *c;
99
100         assert(context);
101
102         while ((es = set_steal_first(context->listen)))
103                 sd_event_source_unref(es);
104
105         while ((c = set_first(context->connections)))
106                 connection_free(c);
107
108         set_free(context->listen);
109         set_free(context->connections);
110
111         sd_event_unref(context->event);
112         sd_resolve_unref(context->resolve);
113 }
114
115 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
116         int r;
117
118         assert(c);
119         assert(buffer);
120         assert(sz);
121
122         if (buffer[0] >= 0)
123                 return 0;
124
125         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
126         if (r < 0) {
127                 log_error("Failed to allocate pipe buffer: %m");
128                 return -errno;
129         }
130
131         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
132
133         r = fcntl(buffer[0], F_GETPIPE_SZ);
134         if (r < 0) {
135                 log_error("Failed to get pipe buffer size: %m");
136                 return -errno;
137         }
138
139         assert(r > 0);
140         *sz = r;
141
142         return 0;
143 }
144
145 static int connection_shovel(
146                 Connection *c,
147                 int *from, int buffer[2], int *to,
148                 size_t *full, size_t *sz,
149                 sd_event_source **from_source, sd_event_source **to_source) {
150
151         bool shoveled;
152
153         assert(c);
154         assert(from);
155         assert(buffer);
156         assert(buffer[0] >= 0);
157         assert(buffer[1] >= 0);
158         assert(to);
159         assert(full);
160         assert(sz);
161         assert(from_source);
162         assert(to_source);
163
164         do {
165                 ssize_t z;
166
167                 shoveled = false;
168
169                 if (*full < *sz && *from >= 0 && *to >= 0) {
170                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
171                         if (z > 0) {
172                                 *full += z;
173                                 shoveled = true;
174                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
175                                 *from_source = sd_event_source_unref(*from_source);
176                                 *from = safe_close(*from);
177                         } else if (errno != EAGAIN && errno != EINTR) {
178                                 log_error("Failed to splice: %m");
179                                 return -errno;
180                         }
181                 }
182
183                 if (*full > 0 && *to >= 0) {
184                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
185                         if (z > 0) {
186                                 *full -= z;
187                                 shoveled = true;
188                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
189                                 *to_source = sd_event_source_unref(*to_source);
190                                 *to = safe_close(*to);
191                         } else if (errno != EAGAIN && errno != EINTR) {
192                                 log_error("Failed to splice: %m");
193                                 return -errno;
194                         }
195                 }
196         } while (shoveled);
197
198         return 0;
199 }
200
201 static int connection_enable_event_sources(Connection *c);
202
203 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
204         Connection *c = userdata;
205         int r;
206
207         assert(s);
208         assert(fd >= 0);
209         assert(c);
210
211         r = connection_shovel(c,
212                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
213                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
214                               &c->server_event_source, &c->client_event_source);
215         if (r < 0)
216                 goto quit;
217
218         r = connection_shovel(c,
219                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
220                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
221                               &c->client_event_source, &c->server_event_source);
222         if (r < 0)
223                 goto quit;
224
225         /* EOF on both sides? */
226         if (c->server_fd == -1 && c->client_fd == -1)
227                 goto quit;
228
229         /* Server closed, and all data written to client? */
230         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
231                 goto quit;
232
233         /* Client closed, and all data written to server? */
234         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
235                 goto quit;
236
237         r = connection_enable_event_sources(c);
238         if (r < 0)
239                 goto quit;
240
241         return 1;
242
243 quit:
244         connection_free(c);
245         return 0; /* ignore errors, continue serving */
246 }
247
248 static int connection_enable_event_sources(Connection *c) {
249         uint32_t a = 0, b = 0;
250         int r;
251
252         assert(c);
253
254         if (c->server_to_client_buffer_full > 0)
255                 b |= EPOLLOUT;
256         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
257                 a |= EPOLLIN;
258
259         if (c->client_to_server_buffer_full > 0)
260                 a |= EPOLLOUT;
261         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
262                 b |= EPOLLIN;
263
264         if (c->server_event_source)
265                 r = sd_event_source_set_io_events(c->server_event_source, a);
266         else if (c->server_fd >= 0)
267                 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
268         else
269                 r = 0;
270
271         if (r < 0) {
272                 log_error("Failed to set up server event source: %s", strerror(-r));
273                 return r;
274         }
275
276         if (c->client_event_source)
277                 r = sd_event_source_set_io_events(c->client_event_source, b);
278         else if (c->client_fd >= 0)
279                 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
280         else
281                 r = 0;
282
283         if (r < 0) {
284                 log_error("Failed to set up client event source: %s", strerror(-r));
285                 return r;
286         }
287
288         return 0;
289 }
290
291 static int connection_complete(Connection *c) {
292         int r;
293
294         assert(c);
295
296         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
297         if (r < 0)
298                 goto fail;
299
300         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
301         if (r < 0)
302                 goto fail;
303
304         r = connection_enable_event_sources(c);
305         if (r < 0)
306                 goto fail;
307
308         return 0;
309
310 fail:
311         connection_free(c);
312         return 0; /* ignore errors, continue serving */
313 }
314
315 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
316         Connection *c = userdata;
317         socklen_t solen;
318         int error, r;
319
320         assert(s);
321         assert(fd >= 0);
322         assert(c);
323
324         solen = sizeof(error);
325         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
326         if (r < 0) {
327                 log_error("Failed to issue SO_ERROR: %m");
328                 goto fail;
329         }
330
331         if (error != 0) {
332                 log_error("Failed to connect to remote host: %s", strerror(error));
333                 goto fail;
334         }
335
336         c->client_event_source = sd_event_source_unref(c->client_event_source);
337
338         return connection_complete(c);
339
340 fail:
341         connection_free(c);
342         return 0; /* ignore errors, continue serving */
343 }
344
345 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
346         int r;
347
348         assert(c);
349         assert(sa);
350         assert(salen);
351
352         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
353         if (c->client_fd < 0) {
354                 log_error("Failed to get remote socket: %m");
355                 goto fail;
356         }
357
358         r = connect(c->client_fd, sa, salen);
359         if (r < 0) {
360                 if (errno == EINPROGRESS) {
361                         r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
362                         if (r < 0) {
363                                 log_error("Failed to add connection socket: %s", strerror(-r));
364                                 goto fail;
365                         }
366
367                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
368                         if (r < 0) {
369                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
370                                 goto fail;
371                         }
372                 } else {
373                         log_error("Failed to connect to remote host: %m");
374                         goto fail;
375                 }
376         } else {
377                 r = connection_complete(c);
378                 if (r < 0)
379                         goto fail;
380         }
381
382         return 0;
383
384 fail:
385         connection_free(c);
386         return 0; /* ignore errors, continue serving */
387 }
388
389 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
390         Connection *c = userdata;
391
392         assert(q);
393         assert(c);
394
395         if (ret != 0) {
396                 log_error("Failed to resolve host: %s", gai_strerror(ret));
397                 goto fail;
398         }
399
400         c->resolve_query = sd_resolve_query_unref(c->resolve_query);
401
402         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
403
404 fail:
405         connection_free(c);
406         return 0; /* ignore errors, continue serving */
407 }
408
409 static int resolve_remote(Connection *c) {
410
411         static const struct addrinfo hints = {
412                 .ai_family = AF_UNSPEC,
413                 .ai_socktype = SOCK_STREAM,
414                 .ai_flags = AI_ADDRCONFIG
415         };
416
417         union sockaddr_union sa = {};
418         const char *node, *service;
419         socklen_t salen;
420         int r;
421
422         if (path_is_absolute(arg_remote_host)) {
423                 sa.un.sun_family = AF_UNIX;
424                 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
425                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
426
427                 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
428
429                 return connection_start(c, &sa.sa, salen);
430         }
431
432         if (arg_remote_host[0] == '@') {
433                 sa.un.sun_family = AF_UNIX;
434                 sa.un.sun_path[0] = 0;
435                 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
436                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
437
438                 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
439
440                 return connection_start(c, &sa.sa, salen);
441         }
442
443         service = strrchr(arg_remote_host, ':');
444         if (service) {
445                 node = strndupa(arg_remote_host, service - arg_remote_host);
446                 service ++;
447         } else {
448                 node = arg_remote_host;
449                 service = "80";
450         }
451
452         log_debug("Looking up address info for %s:%s", node, service);
453         r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
454         if (r < 0) {
455                 log_error("Failed to resolve remote host: %s", strerror(-r));
456                 goto fail;
457         }
458
459         return 0;
460
461 fail:
462         connection_free(c);
463         return 0; /* ignore errors, continue serving */
464 }
465
466 static int add_connection_socket(Context *context, int fd) {
467         Connection *c;
468         int r;
469
470         assert(context);
471         assert(fd >= 0);
472
473         if (set_size(context->connections) > CONNECTIONS_MAX) {
474                 log_warning("Hit connection limit, refusing connection.");
475                 safe_close(fd);
476                 return 0;
477         }
478
479         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
480         if (r < 0) {
481                 log_oom();
482                 return 0;
483         }
484
485         c = new0(Connection, 1);
486         if (!c) {
487                 log_oom();
488                 return 0;
489         }
490
491         c->context = context;
492         c->server_fd = fd;
493         c->client_fd = -1;
494         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
495         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
496
497         r = set_put(context->connections, c);
498         if (r < 0) {
499                 free(c);
500                 log_oom();
501                 return 0;
502         }
503
504         return resolve_remote(c);
505 }
506
507 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
508         _cleanup_free_ char *peer = NULL;
509         Context *context = userdata;
510         int nfd = -1, r;
511
512         assert(s);
513         assert(fd >= 0);
514         assert(revents & EPOLLIN);
515         assert(context);
516
517         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
518         if (nfd < 0) {
519                 if (errno != -EAGAIN)
520                         log_warning("Failed to accept() socket: %m");
521         } else {
522                 getpeername_pretty(nfd, &peer);
523                 log_debug("New connection from %s", strna(peer));
524
525                 r = add_connection_socket(context, nfd);
526                 if (r < 0) {
527                         log_error("Failed to accept connection, ignoring: %s", strerror(-r));
528                         safe_close(fd);
529                 }
530         }
531
532         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
533         if (r < 0) {
534                 log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
535                 sd_event_exit(context->event, r);
536                 return r;
537         }
538
539         return 1;
540 }
541
542 static int add_listen_socket(Context *context, int fd) {
543         sd_event_source *source;
544         int r;
545
546         assert(context);
547         assert(fd >= 0);
548
549         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
550         if (r < 0) {
551                 log_oom();
552                 return r;
553         }
554
555         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
556         if (r < 0) {
557                 log_error("Failed to determine socket type: %s", strerror(-r));
558                 return r;
559         }
560         if (r == 0) {
561                 log_error("Passed in socket is not a stream socket.");
562                 return -EINVAL;
563         }
564
565         r = fd_nonblock(fd, true);
566         if (r < 0) {
567                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
568                 return r;
569         }
570
571         r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
572         if (r < 0) {
573                 log_error("Failed to add event source: %s", strerror(-r));
574                 return r;
575         }
576
577         r = set_put(context->listen, source);
578         if (r < 0) {
579                 log_error("Failed to add source to set: %s", strerror(-r));
580                 sd_event_source_unref(source);
581                 return r;
582         }
583
584         /* Set the watcher to oneshot in case other processes are also
585          * watching to accept(). */
586         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
587         if (r < 0) {
588                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
589                 return r;
590         }
591
592         return 0;
593 }
594
595 static int help(void) {
596
597         printf("%s [HOST:PORT]\n"
598                "%s [SOCKET]\n\n"
599                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
600                "  -h --help              Show this help\n"
601                "     --version           Show package version\n",
602                program_invocation_short_name,
603                program_invocation_short_name);
604
605         return 0;
606 }
607
608 static int parse_argv(int argc, char *argv[]) {
609
610         enum {
611                 ARG_VERSION = 0x100,
612                 ARG_IGNORE_ENV
613         };
614
615         static const struct option options[] = {
616                 { "help",       no_argument, NULL, 'h'           },
617                 { "version",    no_argument, NULL, ARG_VERSION   },
618                 {}
619         };
620
621         int c;
622
623         assert(argc >= 0);
624         assert(argv);
625
626         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
627
628                 switch (c) {
629
630                 case 'h':
631                         return help();
632
633                 case ARG_VERSION:
634                         puts(PACKAGE_STRING);
635                         puts(SYSTEMD_FEATURES);
636                         return 0;
637
638                 case '?':
639                         return -EINVAL;
640
641                 default:
642                         assert_not_reached("Unhandled option");
643                 }
644         }
645
646         if (optind >= argc) {
647                 log_error("Not enough parameters.");
648                 return -EINVAL;
649         }
650
651         if (argc != optind+1) {
652                 log_error("Too many parameters.");
653                 return -EINVAL;
654         }
655
656         arg_remote_host = argv[optind];
657         return 1;
658 }
659
660 int main(int argc, char *argv[]) {
661         Context context = {};
662         int r, n, fd;
663
664         log_parse_environment();
665         log_open();
666
667         r = parse_argv(argc, argv);
668         if (r <= 0)
669                 goto finish;
670
671         r = sd_event_default(&context.event);
672         if (r < 0) {
673                 log_error("Failed to allocate event loop: %s", strerror(-r));
674                 goto finish;
675         }
676
677         r = sd_resolve_default(&context.resolve);
678         if (r < 0) {
679                 log_error("Failed to allocate resolver: %s", strerror(-r));
680                 goto finish;
681         }
682
683         r = sd_resolve_attach_event(context.resolve, context.event, 0);
684         if (r < 0) {
685                 log_error("Failed to attach resolver: %s", strerror(-r));
686                 goto finish;
687         }
688
689         sd_event_set_watchdog(context.event, true);
690
691         n = sd_listen_fds(1);
692         if (n < 0) {
693                 log_error("Failed to receive sockets from parent.");
694                 r = n;
695                 goto finish;
696         } else if (n == 0) {
697                 log_error("Didn't get any sockets passed in.");
698                 r = -EINVAL;
699                 goto finish;
700         }
701
702         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
703                 r = add_listen_socket(&context, fd);
704                 if (r < 0)
705                         goto finish;
706         }
707
708         r = sd_event_loop(context.event);
709         if (r < 0) {
710                 log_error("Failed to run event loop: %s", strerror(-r));
711                 goto finish;
712         }
713
714 finish:
715         context_free(&context);
716
717         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
718 }