chiark / gitweb /
sd-rtnl: fix bogus warning about dropping 20 bytes from multi-part messages
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "sd-resolve.h"
37 #include "log.h"
38 #include "socket-util.h"
39 #include "util.h"
40 #include "event-util.h"
41 #include "build.h"
42 #include "set.h"
43 #include "path-util.h"
44
45 #define BUFFER_SIZE (256 * 1024)
46 #define CONNECTIONS_MAX 256
47
48 static const char *arg_remote_host = NULL;
49
50 typedef struct Context {
51         sd_event *event;
52         sd_resolve *resolve;
53
54         Set *listen;
55         Set *connections;
56 } Context;
57
58 typedef struct Connection {
59         Context *context;
60
61         int server_fd, client_fd;
62         int server_to_client_buffer[2]; /* a pipe */
63         int client_to_server_buffer[2]; /* a pipe */
64
65         size_t server_to_client_buffer_full, client_to_server_buffer_full;
66         size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
68         sd_event_source *server_event_source, *client_event_source;
69
70         sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74         assert(c);
75
76         if (c->context)
77                 set_remove(c->context->connections, c);
78
79         sd_event_source_unref(c->server_event_source);
80         sd_event_source_unref(c->client_event_source);
81
82         safe_close(c->server_fd);
83         safe_close(c->client_fd);
84
85         safe_close_pair(c->server_to_client_buffer);
86         safe_close_pair(c->client_to_server_buffer);
87
88         sd_resolve_query_unref(c->resolve_query);
89
90         free(c);
91 }
92
93 static void context_free(Context *context) {
94         sd_event_source *es;
95         Connection *c;
96
97         assert(context);
98
99         while ((es = set_steal_first(context->listen)))
100                 sd_event_source_unref(es);
101
102         while ((c = set_first(context->connections)))
103                 connection_free(c);
104
105         set_free(context->listen);
106         set_free(context->connections);
107
108         sd_event_unref(context->event);
109         sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113         int r;
114
115         assert(c);
116         assert(buffer);
117         assert(sz);
118
119         if (buffer[0] >= 0)
120                 return 0;
121
122         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123         if (r < 0)
124                 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
125
126         (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
127
128         r = fcntl(buffer[0], F_GETPIPE_SZ);
129         if (r < 0)
130                 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
131
132         assert(r > 0);
133         *sz = r;
134
135         return 0;
136 }
137
138 static int connection_shovel(
139                 Connection *c,
140                 int *from, int buffer[2], int *to,
141                 size_t *full, size_t *sz,
142                 sd_event_source **from_source, sd_event_source **to_source) {
143
144         bool shoveled;
145
146         assert(c);
147         assert(from);
148         assert(buffer);
149         assert(buffer[0] >= 0);
150         assert(buffer[1] >= 0);
151         assert(to);
152         assert(full);
153         assert(sz);
154         assert(from_source);
155         assert(to_source);
156
157         do {
158                 ssize_t z;
159
160                 shoveled = false;
161
162                 if (*full < *sz && *from >= 0 && *to >= 0) {
163                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
164                         if (z > 0) {
165                                 *full += z;
166                                 shoveled = true;
167                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
168                                 *from_source = sd_event_source_unref(*from_source);
169                                 *from = safe_close(*from);
170                         } else if (errno != EAGAIN && errno != EINTR)
171                                 return log_error_errno(errno, "Failed to splice: %m");
172                 }
173
174                 if (*full > 0 && *to >= 0) {
175                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
176                         if (z > 0) {
177                                 *full -= z;
178                                 shoveled = true;
179                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
180                                 *to_source = sd_event_source_unref(*to_source);
181                                 *to = safe_close(*to);
182                         } else if (errno != EAGAIN && errno != EINTR)
183                                 return log_error_errno(errno, "Failed to splice: %m");
184                 }
185         } while (shoveled);
186
187         return 0;
188 }
189
190 static int connection_enable_event_sources(Connection *c);
191
192 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
193         Connection *c = userdata;
194         int r;
195
196         assert(s);
197         assert(fd >= 0);
198         assert(c);
199
200         r = connection_shovel(c,
201                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
202                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
203                               &c->server_event_source, &c->client_event_source);
204         if (r < 0)
205                 goto quit;
206
207         r = connection_shovel(c,
208                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
209                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
210                               &c->client_event_source, &c->server_event_source);
211         if (r < 0)
212                 goto quit;
213
214         /* EOF on both sides? */
215         if (c->server_fd == -1 && c->client_fd == -1)
216                 goto quit;
217
218         /* Server closed, and all data written to client? */
219         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
220                 goto quit;
221
222         /* Client closed, and all data written to server? */
223         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
224                 goto quit;
225
226         r = connection_enable_event_sources(c);
227         if (r < 0)
228                 goto quit;
229
230         return 1;
231
232 quit:
233         connection_free(c);
234         return 0; /* ignore errors, continue serving */
235 }
236
237 static int connection_enable_event_sources(Connection *c) {
238         uint32_t a = 0, b = 0;
239         int r;
240
241         assert(c);
242
243         if (c->server_to_client_buffer_full > 0)
244                 b |= EPOLLOUT;
245         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
246                 a |= EPOLLIN;
247
248         if (c->client_to_server_buffer_full > 0)
249                 a |= EPOLLOUT;
250         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
251                 b |= EPOLLIN;
252
253         if (c->server_event_source)
254                 r = sd_event_source_set_io_events(c->server_event_source, a);
255         else if (c->server_fd >= 0)
256                 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
257         else
258                 r = 0;
259
260         if (r < 0)
261                 return log_error_errno(r, "Failed to set up server event source: %m");
262
263         if (c->client_event_source)
264                 r = sd_event_source_set_io_events(c->client_event_source, b);
265         else if (c->client_fd >= 0)
266                 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
267         else
268                 r = 0;
269
270         if (r < 0)
271                 return log_error_errno(r, "Failed to set up client event source: %m");
272
273         return 0;
274 }
275
276 static int connection_complete(Connection *c) {
277         int r;
278
279         assert(c);
280
281         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
282         if (r < 0)
283                 goto fail;
284
285         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
286         if (r < 0)
287                 goto fail;
288
289         r = connection_enable_event_sources(c);
290         if (r < 0)
291                 goto fail;
292
293         return 0;
294
295 fail:
296         connection_free(c);
297         return 0; /* ignore errors, continue serving */
298 }
299
300 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
301         Connection *c = userdata;
302         socklen_t solen;
303         int error, r;
304
305         assert(s);
306         assert(fd >= 0);
307         assert(c);
308
309         solen = sizeof(error);
310         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
311         if (r < 0) {
312                 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
313                 goto fail;
314         }
315
316         if (error != 0) {
317                 log_error_errno(error, "Failed to connect to remote host: %m");
318                 goto fail;
319         }
320
321         c->client_event_source = sd_event_source_unref(c->client_event_source);
322
323         return connection_complete(c);
324
325 fail:
326         connection_free(c);
327         return 0; /* ignore errors, continue serving */
328 }
329
330 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
331         int r;
332
333         assert(c);
334         assert(sa);
335         assert(salen);
336
337         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
338         if (c->client_fd < 0) {
339                 log_error_errno(errno, "Failed to get remote socket: %m");
340                 goto fail;
341         }
342
343         r = connect(c->client_fd, sa, salen);
344         if (r < 0) {
345                 if (errno == EINPROGRESS) {
346                         r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
347                         if (r < 0) {
348                                 log_error_errno(r, "Failed to add connection socket: %m");
349                                 goto fail;
350                         }
351
352                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
353                         if (r < 0) {
354                                 log_error_errno(r, "Failed to enable oneshot event source: %m");
355                                 goto fail;
356                         }
357                 } else {
358                         log_error_errno(errno, "Failed to connect to remote host: %m");
359                         goto fail;
360                 }
361         } else {
362                 r = connection_complete(c);
363                 if (r < 0)
364                         goto fail;
365         }
366
367         return 0;
368
369 fail:
370         connection_free(c);
371         return 0; /* ignore errors, continue serving */
372 }
373
374 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
375         Connection *c = userdata;
376
377         assert(q);
378         assert(c);
379
380         if (ret != 0) {
381                 log_error("Failed to resolve host: %s", gai_strerror(ret));
382                 goto fail;
383         }
384
385         c->resolve_query = sd_resolve_query_unref(c->resolve_query);
386
387         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
388
389 fail:
390         connection_free(c);
391         return 0; /* ignore errors, continue serving */
392 }
393
394 static int resolve_remote(Connection *c) {
395
396         static const struct addrinfo hints = {
397                 .ai_family = AF_UNSPEC,
398                 .ai_socktype = SOCK_STREAM,
399                 .ai_flags = AI_ADDRCONFIG
400         };
401
402         union sockaddr_union sa = {};
403         const char *node, *service;
404         socklen_t salen;
405         int r;
406
407         if (path_is_absolute(arg_remote_host)) {
408                 sa.un.sun_family = AF_UNIX;
409                 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
410                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
411
412                 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
413
414                 return connection_start(c, &sa.sa, salen);
415         }
416
417         if (arg_remote_host[0] == '@') {
418                 sa.un.sun_family = AF_UNIX;
419                 sa.un.sun_path[0] = 0;
420                 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
421                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
422
423                 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
424
425                 return connection_start(c, &sa.sa, salen);
426         }
427
428         service = strrchr(arg_remote_host, ':');
429         if (service) {
430                 node = strndupa(arg_remote_host, service - arg_remote_host);
431                 service ++;
432         } else {
433                 node = arg_remote_host;
434                 service = "80";
435         }
436
437         log_debug("Looking up address info for %s:%s", node, service);
438         r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
439         if (r < 0) {
440                 log_error_errno(r, "Failed to resolve remote host: %m");
441                 goto fail;
442         }
443
444         return 0;
445
446 fail:
447         connection_free(c);
448         return 0; /* ignore errors, continue serving */
449 }
450
451 static int add_connection_socket(Context *context, int fd) {
452         Connection *c;
453         int r;
454
455         assert(context);
456         assert(fd >= 0);
457
458         if (set_size(context->connections) > CONNECTIONS_MAX) {
459                 log_warning("Hit connection limit, refusing connection.");
460                 safe_close(fd);
461                 return 0;
462         }
463
464         r = set_ensure_allocated(&context->connections, NULL);
465         if (r < 0) {
466                 log_oom();
467                 return 0;
468         }
469
470         c = new0(Connection, 1);
471         if (!c) {
472                 log_oom();
473                 return 0;
474         }
475
476         c->context = context;
477         c->server_fd = fd;
478         c->client_fd = -1;
479         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
480         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
481
482         r = set_put(context->connections, c);
483         if (r < 0) {
484                 free(c);
485                 log_oom();
486                 return 0;
487         }
488
489         return resolve_remote(c);
490 }
491
492 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
493         _cleanup_free_ char *peer = NULL;
494         Context *context = userdata;
495         int nfd = -1, r;
496
497         assert(s);
498         assert(fd >= 0);
499         assert(revents & EPOLLIN);
500         assert(context);
501
502         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
503         if (nfd < 0) {
504                 if (errno != -EAGAIN)
505                         log_warning_errno(errno, "Failed to accept() socket: %m");
506         } else {
507                 getpeername_pretty(nfd, &peer);
508                 log_debug("New connection from %s", strna(peer));
509
510                 r = add_connection_socket(context, nfd);
511                 if (r < 0) {
512                         log_error_errno(r, "Failed to accept connection, ignoring: %m");
513                         safe_close(fd);
514                 }
515         }
516
517         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
518         if (r < 0) {
519                 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
520                 sd_event_exit(context->event, r);
521                 return r;
522         }
523
524         return 1;
525 }
526
527 static int add_listen_socket(Context *context, int fd) {
528         sd_event_source *source;
529         int r;
530
531         assert(context);
532         assert(fd >= 0);
533
534         r = set_ensure_allocated(&context->listen, NULL);
535         if (r < 0) {
536                 log_oom();
537                 return r;
538         }
539
540         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
541         if (r < 0)
542                 return log_error_errno(r, "Failed to determine socket type: %m");
543         if (r == 0) {
544                 log_error("Passed in socket is not a stream socket.");
545                 return -EINVAL;
546         }
547
548         r = fd_nonblock(fd, true);
549         if (r < 0)
550                 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
551
552         r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
553         if (r < 0)
554                 return log_error_errno(r, "Failed to add event source: %m");
555
556         r = set_put(context->listen, source);
557         if (r < 0) {
558                 log_error_errno(r, "Failed to add source to set: %m");
559                 sd_event_source_unref(source);
560                 return r;
561         }
562
563         /* Set the watcher to oneshot in case other processes are also
564          * watching to accept(). */
565         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
566         if (r < 0)
567                 return log_error_errno(r, "Failed to enable oneshot mode: %m");
568
569         return 0;
570 }
571
572 static void help(void) {
573         printf("%1$s [HOST:PORT]\n"
574                "%1$s [SOCKET]\n\n"
575                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
576                "  -h --help              Show this help\n"
577                "     --version           Show package version\n",
578                program_invocation_short_name);
579 }
580
581 static int parse_argv(int argc, char *argv[]) {
582
583         enum {
584                 ARG_VERSION = 0x100,
585                 ARG_IGNORE_ENV
586         };
587
588         static const struct option options[] = {
589                 { "help",       no_argument, NULL, 'h'           },
590                 { "version",    no_argument, NULL, ARG_VERSION   },
591                 {}
592         };
593
594         int c;
595
596         assert(argc >= 0);
597         assert(argv);
598
599         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
600
601                 switch (c) {
602
603                 case 'h':
604                         help();
605                         return 0;
606
607                 case ARG_VERSION:
608                         puts(PACKAGE_STRING);
609                         puts(SYSTEMD_FEATURES);
610                         return 0;
611
612                 case '?':
613                         return -EINVAL;
614
615                 default:
616                         assert_not_reached("Unhandled option");
617                 }
618
619         if (optind >= argc) {
620                 log_error("Not enough parameters.");
621                 return -EINVAL;
622         }
623
624         if (argc != optind+1) {
625                 log_error("Too many parameters.");
626                 return -EINVAL;
627         }
628
629         arg_remote_host = argv[optind];
630         return 1;
631 }
632
633 int main(int argc, char *argv[]) {
634         Context context = {};
635         int r, n, fd;
636
637         log_parse_environment();
638         log_open();
639
640         r = parse_argv(argc, argv);
641         if (r <= 0)
642                 goto finish;
643
644         r = sd_event_default(&context.event);
645         if (r < 0) {
646                 log_error_errno(r, "Failed to allocate event loop: %m");
647                 goto finish;
648         }
649
650         r = sd_resolve_default(&context.resolve);
651         if (r < 0) {
652                 log_error_errno(r, "Failed to allocate resolver: %m");
653                 goto finish;
654         }
655
656         r = sd_resolve_attach_event(context.resolve, context.event, 0);
657         if (r < 0) {
658                 log_error_errno(r, "Failed to attach resolver: %m");
659                 goto finish;
660         }
661
662         sd_event_set_watchdog(context.event, true);
663
664         n = sd_listen_fds(1);
665         if (n < 0) {
666                 log_error("Failed to receive sockets from parent.");
667                 r = n;
668                 goto finish;
669         } else if (n == 0) {
670                 log_error("Didn't get any sockets passed in.");
671                 r = -EINVAL;
672                 goto finish;
673         }
674
675         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
676                 r = add_listen_socket(&context, fd);
677                 if (r < 0)
678                         goto finish;
679         }
680
681         r = sd_event_loop(context.event);
682         if (r < 0) {
683                 log_error_errno(r, "Failed to run event loop: %m");
684                 goto finish;
685         }
686
687 finish:
688         context_free(&context);
689
690         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
691 }