chiark / gitweb /
socket-proxyd: rework to support multiple sockets and splice()-based zero-copy network IO
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         int server_fd, client_fd;
57         int server_to_client_buffer[2]; /* a pipe */
58         int client_to_server_buffer[2]; /* a pipe */
59
60         size_t server_to_client_buffer_full, client_to_server_buffer_full;
61         size_t server_to_client_buffer_size, client_to_server_buffer_size;
62
63         sd_event_source *server_event_source, *client_event_source;
64 } Connection;
65
66 union sockaddr_any {
67         struct sockaddr sa;
68         struct sockaddr_un un;
69         struct sockaddr_in in;
70         struct sockaddr_in6 in6;
71         struct sockaddr_storage storage;
72 };
73
74 static const char *arg_remote_host = NULL;
75
76 static void connection_free(Connection *c) {
77         assert(c);
78
79         sd_event_source_unref(c->server_event_source);
80         sd_event_source_unref(c->client_event_source);
81
82         if (c->server_fd >= 0)
83                 close_nointr_nofail(c->server_fd);
84         if (c->client_fd >= 0)
85                 close_nointr_nofail(c->client_fd);
86
87         close_pipe(c->server_to_client_buffer);
88         close_pipe(c->client_to_server_buffer);
89
90         free(c);
91 }
92
93 static void context_free(Context *context) {
94         sd_event_source *es;
95         Connection *c;
96
97         assert(context);
98
99         while ((es = set_steal_first(context->listen)))
100                 sd_event_source_unref(es);
101
102         while ((c = set_steal_first(context->connections)))
103                 connection_free(c);
104
105         set_free(context->listen);
106         set_free(context->connections);
107 }
108
109 static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) {
110         int r;
111
112         assert(sa);
113         assert(salen);
114
115         if (path_is_absolute(arg_remote_host)) {
116                 sa->un.sun_family = AF_UNIX;
117                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
118                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
119
120                 *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path);
121
122         } else if (arg_remote_host[0] == '@') {
123                 sa->un.sun_family = AF_UNIX;
124                 sa->un.sun_path[0] = 0;
125                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
126                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
127
128                 *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
129
130         } else {
131                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
132                 const char *node, *service;
133
134                 struct addrinfo hints = {
135                         .ai_family = AF_UNSPEC,
136                         .ai_socktype = SOCK_STREAM,
137                         .ai_flags = AI_ADDRCONFIG
138                 };
139
140                 service = strrchr(arg_remote_host, ':');
141                 if (service) {
142                         node = strndupa(arg_remote_host, service - arg_remote_host);
143                         service ++;
144                 } else {
145                         node = arg_remote_host;
146                         service = "80";
147                 }
148
149                 log_debug("Looking up address info for %s:%s", node, service);
150                 r = getaddrinfo(node, service, &hints, &result);
151                 if (r != 0) {
152                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
153                         return -EHOSTUNREACH;
154                 }
155
156                 assert(result);
157                 if (result->ai_addrlen > sizeof(union sockaddr_any)) {
158                         log_error("Address too long.");
159                         return -E2BIG;
160                 }
161
162                 memcpy(sa, result->ai_addr, result->ai_addrlen);
163                 *salen = result->ai_addrlen;
164         }
165
166         return 0;
167 }
168
169 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
170         int r;
171
172         assert(c);
173         assert(buffer);
174         assert(sz);
175
176         if (buffer[0] >= 0)
177                 return 0;
178
179         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
180         if (r < 0) {
181                 log_error("Failed to allocate pipe buffer: %m");
182                 return -errno;
183         }
184
185         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
186
187         r = fcntl(buffer[0], F_GETPIPE_SZ);
188         if (r < 0) {
189                 log_error("Failed to get pipe buffer size: %m");
190                 return -errno;
191         }
192
193         assert(r > 0);
194         *sz = r;
195
196         return 0;
197 }
198
199 static int connection_shovel(
200                 Connection *c,
201                 int *from, int buffer[2], int *to,
202                 size_t *full, size_t *sz,
203                 sd_event_source **from_source, sd_event_source **to_source) {
204
205         bool shoveled;
206
207         assert(c);
208         assert(from);
209         assert(buffer);
210         assert(buffer[0] >= 0);
211         assert(buffer[1] >= 0);
212         assert(to);
213         assert(full);
214         assert(sz);
215         assert(from_source);
216         assert(to_source);
217
218         do {
219                 ssize_t z;
220
221                 shoveled = false;
222
223                 if (*full < *sz && *from >= 0 && *to >= 0) {
224                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
225                         if (z > 0) {
226                                 *full += z;
227                                 shoveled = true;
228                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
229                                 *from_source = sd_event_source_unref(*from_source);
230                                 close_nointr_nofail(*from);
231                                 *from = -1;
232                         } else if (errno != EAGAIN && errno != EINTR) {
233                                 log_error("Failed to splice: %m");
234                                 return -errno;
235                         }
236                 }
237
238                 if (*full > 0 && *to >= 0) {
239                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
240                         if (z > 0) {
241                                 *full -= z;
242                                 shoveled = true;
243                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
244                                 *to_source = sd_event_source_unref(*to_source);
245                                 close_nointr_nofail(*to);
246                                 *to = -1;
247                         } else if (errno != EAGAIN && errno != EINTR) {
248                                 log_error("Failed to splice: %m");
249                                 return -errno;
250                         }
251                 }
252         } while (shoveled);
253
254         return 0;
255 }
256
257 static int connection_enable_event_sources(Connection *c, sd_event *event);
258
259 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
260         Connection *c = userdata;
261         int r;
262
263         assert(s);
264         assert(fd >= 0);
265         assert(c);
266
267         r = connection_shovel(c,
268                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
269                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
270                               &c->server_event_source, &c->client_event_source);
271         if (r < 0)
272                 goto quit;
273
274         r = connection_shovel(c,
275                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
276                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
277                               &c->client_event_source, &c->server_event_source);
278         if (r < 0)
279                 goto quit;
280
281         /* EOF on both sides? */
282         if (c->server_fd == -1 && c->client_fd == -1)
283                 goto quit;
284
285         /* Server closed, and all data written to client? */
286         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
287                 goto quit;
288
289         /* Client closed, and all data written to server? */
290         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
291                 goto quit;
292
293         r = connection_enable_event_sources(c, sd_event_get(s));
294         if (r < 0)
295                 goto quit;
296
297         return 1;
298
299 quit:
300         connection_free(c);
301         return 0; /* ignore errors, continue serving */
302 }
303
304 static int connection_enable_event_sources(Connection *c, sd_event *event) {
305         uint32_t a = 0, b = 0;
306         int r;
307
308         assert(c);
309         assert(event);
310
311         if (c->server_to_client_buffer_full > 0)
312                 b |= EPOLLOUT;
313         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
314                 a |= EPOLLIN;
315
316         if (c->client_to_server_buffer_full > 0)
317                 a |= EPOLLOUT;
318         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
319                 b |= EPOLLIN;
320
321         if (c->server_event_source)
322                 r = sd_event_source_set_io_events(c->server_event_source, a);
323         else if (c->server_fd >= 0)
324                 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
325         else
326                 r = 0;
327
328         if (r < 0) {
329                 log_error("Failed to set up server event source: %s", strerror(-r));
330                 return r;
331         }
332
333         if (c->client_event_source)
334                 r = sd_event_source_set_io_events(c->client_event_source, b);
335         else if (c->client_fd >= 0)
336                 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
337         else
338                 r = 0;
339
340         if (r < 0) {
341                 log_error("Failed to set up server event source: %s", strerror(-r));
342                 return r;
343         }
344
345         return 0;
346 }
347
348 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
349         Connection *c = userdata;
350         socklen_t solen;
351         int error, r;
352
353         assert(s);
354         assert(fd >= 0);
355         assert(c);
356
357         solen = sizeof(error);
358         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
359         if (r < 0) {
360                 log_error("Failed to issue SO_ERROR: %m");
361                 goto fail;
362         }
363
364         if (error != 0) {
365                 log_error("Failed to connect to remote host: %s", strerror(error));
366                 goto fail;
367         }
368
369         c->client_event_source = sd_event_source_unref(c->client_event_source);
370
371         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
372         if (r < 0)
373                 goto fail;
374
375         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
376         if (r < 0)
377                 goto fail;
378
379         r = connection_enable_event_sources(c, sd_event_get(s));
380         if (r < 0)
381                 goto fail;
382
383         return 0;
384
385 fail:
386         connection_free(c);
387         return 0; /* ignore errors, continue serving */
388 }
389
390 static int add_connection_socket(Context *context, sd_event *event, int fd) {
391         union sockaddr_any sa = {};
392         socklen_t salen;
393         Connection *c;
394         int r;
395
396         assert(context);
397         assert(event);
398         assert(fd >= 0);
399
400         if (set_size(context->connections) > CONNECTIONS_MAX) {
401                 log_warning("Hit connection limit, refusing connection.");
402                 close_nointr_nofail(fd);
403                 return 0;
404         }
405
406         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
407         if (r < 0)
408                 return log_oom();
409
410         c = new0(Connection, 1);
411         if (!c)
412                 return log_oom();
413
414         c->server_fd = fd;
415         c->client_fd = -1;
416         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
417         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
418
419         r = get_remote_sockaddr(&sa, &salen);
420         if (r < 0)
421                 goto fail;
422
423         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
424         if (c->client_fd < 0) {
425                 log_error("Failed to get remote socket: %m");
426                 goto fail;
427         }
428
429         r = connect(c->client_fd, &sa.sa, salen);
430         if (r < 0) {
431                 if (errno == EINPROGRESS) {
432                         r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
433                         if (r < 0) {
434                                 log_error("Failed to add connection socket: %s", strerror(-r));
435                                 goto fail;
436                         }
437                 } else {
438                         log_error("Failed to connect to remote host: %m");
439                         goto fail;
440                 }
441         } else {
442                 r = connection_enable_event_sources(c, event);
443                 if (r < 0)
444                         goto fail;
445         }
446
447         return 0;
448
449 fail:
450         connection_free(c);
451         return 0; /* ignore non-OOM errors, continue serving */
452 }
453
454 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
455         Context *context = userdata;
456         int nfd = -1, r;
457
458         assert(s);
459         assert(fd >= 0);
460         assert(revents & EPOLLIN);
461         assert(context);
462
463         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
464         if (nfd >= 0) {
465                 _cleanup_free_ char *peer = NULL;
466
467                 getpeername_pretty(nfd, &peer);
468                 log_debug("New connection from %s", strna(peer));
469
470                 r = add_connection_socket(context, sd_event_get(s), nfd);
471                 if (r < 0) {
472                         close_nointr_nofail(fd);
473                         return r;
474                 }
475
476         } else if (errno != -EAGAIN)
477                 log_warning("Failed to accept() socket: %m");
478
479         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
480         if (r < 0) {
481                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
482                 return r;
483         }
484
485         return 1;
486 }
487
488 static int add_listen_socket(Context *context, sd_event *event, int fd) {
489         sd_event_source *source;
490         int r;
491
492         assert(context);
493         assert(event);
494         assert(fd >= 0);
495
496         log_info("Listening on %i", fd);
497
498         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
499         if (r < 0) {
500                 log_oom();
501                 return r;
502         }
503
504         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
505         if (r < 0) {
506                 log_error("Failed to determine socket type: %s", strerror(-r));
507                 return r;
508         }
509         if (r == 0) {
510                 log_error("Passed in socket is not a stream socket.");
511                 return -EINVAL;
512         }
513
514         r = fd_nonblock(fd, true);
515         if (r < 0) {
516                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
517                 return r;
518         }
519
520         r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
521         if (r < 0) {
522                 log_error("Failed to add event source: %s", strerror(-r));
523                 return r;
524         }
525
526         r = set_put(context->listen, source);
527         if (r < 0) {
528                 log_error("Failed to add source to set: %s", strerror(-r));
529                 sd_event_source_unref(source);
530                 return r;
531         }
532
533         /* Set the watcher to oneshot in case other processes are also
534          * watching to accept(). */
535         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
536         if (r < 0) {
537                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
538                 return r;
539         }
540
541         return 0;
542 }
543
544 static int help(void) {
545
546         printf("%s [HOST:PORT]\n"
547                "%s [SOCKET]\n\n"
548                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
549                "  -h --help              Show this help\n"
550                "     --version           Show package version\n",
551                program_invocation_short_name,
552                program_invocation_short_name);
553
554         return 0;
555 }
556
557 static int parse_argv(int argc, char *argv[]) {
558
559         enum {
560                 ARG_VERSION = 0x100,
561                 ARG_IGNORE_ENV
562         };
563
564         static const struct option options[] = {
565                 { "help",       no_argument, NULL, 'h'           },
566                 { "version",    no_argument, NULL, ARG_VERSION   },
567                 {}
568         };
569
570         int c;
571
572         assert(argc >= 0);
573         assert(argv);
574
575         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
576
577                 switch (c) {
578
579                 case 'h':
580                         return help();
581
582                 case ARG_VERSION:
583                         puts(PACKAGE_STRING);
584                         puts(SYSTEMD_FEATURES);
585                         return 0;
586
587                 case '?':
588                         return -EINVAL;
589
590                 default:
591                         assert_not_reached("Unhandled option");
592                 }
593         }
594
595         if (optind >= argc) {
596                 log_error("Not enough parameters.");
597                 return -EINVAL;
598         }
599
600         if (argc != optind+1) {
601                 log_error("Too many parameters.");
602                 return -EINVAL;
603         }
604
605         arg_remote_host = argv[optind];
606         return 1;
607 }
608
609 int main(int argc, char *argv[]) {
610         _cleanup_event_unref_ sd_event *event = NULL;
611         Context context = {};
612         int r, n, fd;
613
614         log_parse_environment();
615         log_open();
616
617         r = parse_argv(argc, argv);
618         if (r <= 0)
619                 goto finish;
620
621         r = sd_event_new(&event);
622         if (r < 0) {
623                 log_error("Failed to allocate event loop: %s", strerror(-r));
624                 goto finish;
625         }
626
627         n = sd_listen_fds(1);
628         if (n < 0) {
629                 log_error("Failed to receive sockets from parent.");
630                 r = n;
631                 goto finish;
632         } else if (n == 0) {
633                 log_error("Didn't get any sockets passed in.");
634                 r = -EINVAL;
635                 goto finish;
636         }
637
638         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
639                 r = add_listen_socket(&context, event, fd);
640                 if (r < 0)
641                         goto finish;
642         }
643
644         r = sd_event_loop(event);
645         if (r < 0) {
646                 log_error("Failed to run event loop: %s", strerror(-r));
647                 goto finish;
648         }
649
650 finish:
651         context_free(&context);
652
653         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
654 }