chiark / gitweb /
treewide: use log_*_errno whenever %m is in the format string
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "sd-resolve.h"
37 #include "log.h"
38 #include "socket-util.h"
39 #include "util.h"
40 #include "event-util.h"
41 #include "build.h"
42 #include "set.h"
43 #include "path-util.h"
44
45 #define BUFFER_SIZE (256 * 1024)
46 #define CONNECTIONS_MAX 256
47
48 static const char *arg_remote_host = NULL;
49
50 typedef struct Context {
51         sd_event *event;
52         sd_resolve *resolve;
53
54         Set *listen;
55         Set *connections;
56 } Context;
57
58 typedef struct Connection {
59         Context *context;
60
61         int server_fd, client_fd;
62         int server_to_client_buffer[2]; /* a pipe */
63         int client_to_server_buffer[2]; /* a pipe */
64
65         size_t server_to_client_buffer_full, client_to_server_buffer_full;
66         size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
68         sd_event_source *server_event_source, *client_event_source;
69
70         sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74         assert(c);
75
76         if (c->context)
77                 set_remove(c->context->connections, c);
78
79         sd_event_source_unref(c->server_event_source);
80         sd_event_source_unref(c->client_event_source);
81
82         safe_close(c->server_fd);
83         safe_close(c->client_fd);
84
85         safe_close_pair(c->server_to_client_buffer);
86         safe_close_pair(c->client_to_server_buffer);
87
88         sd_resolve_query_unref(c->resolve_query);
89
90         free(c);
91 }
92
93 static void context_free(Context *context) {
94         sd_event_source *es;
95         Connection *c;
96
97         assert(context);
98
99         while ((es = set_steal_first(context->listen)))
100                 sd_event_source_unref(es);
101
102         while ((c = set_first(context->connections)))
103                 connection_free(c);
104
105         set_free(context->listen);
106         set_free(context->connections);
107
108         sd_event_unref(context->event);
109         sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113         int r;
114
115         assert(c);
116         assert(buffer);
117         assert(sz);
118
119         if (buffer[0] >= 0)
120                 return 0;
121
122         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123         if (r < 0) {
124                 log_error_errno(errno, "Failed to allocate pipe buffer: %m");
125                 return -errno;
126         }
127
128         (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
129
130         r = fcntl(buffer[0], F_GETPIPE_SZ);
131         if (r < 0) {
132                 log_error_errno(errno, "Failed to get pipe buffer size: %m");
133                 return -errno;
134         }
135
136         assert(r > 0);
137         *sz = r;
138
139         return 0;
140 }
141
142 static int connection_shovel(
143                 Connection *c,
144                 int *from, int buffer[2], int *to,
145                 size_t *full, size_t *sz,
146                 sd_event_source **from_source, sd_event_source **to_source) {
147
148         bool shoveled;
149
150         assert(c);
151         assert(from);
152         assert(buffer);
153         assert(buffer[0] >= 0);
154         assert(buffer[1] >= 0);
155         assert(to);
156         assert(full);
157         assert(sz);
158         assert(from_source);
159         assert(to_source);
160
161         do {
162                 ssize_t z;
163
164                 shoveled = false;
165
166                 if (*full < *sz && *from >= 0 && *to >= 0) {
167                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
168                         if (z > 0) {
169                                 *full += z;
170                                 shoveled = true;
171                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
172                                 *from_source = sd_event_source_unref(*from_source);
173                                 *from = safe_close(*from);
174                         } else if (errno != EAGAIN && errno != EINTR) {
175                                 log_error_errno(errno, "Failed to splice: %m");
176                                 return -errno;
177                         }
178                 }
179
180                 if (*full > 0 && *to >= 0) {
181                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
182                         if (z > 0) {
183                                 *full -= z;
184                                 shoveled = true;
185                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
186                                 *to_source = sd_event_source_unref(*to_source);
187                                 *to = safe_close(*to);
188                         } else if (errno != EAGAIN && errno != EINTR) {
189                                 log_error_errno(errno, "Failed to splice: %m");
190                                 return -errno;
191                         }
192                 }
193         } while (shoveled);
194
195         return 0;
196 }
197
198 static int connection_enable_event_sources(Connection *c);
199
200 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
201         Connection *c = userdata;
202         int r;
203
204         assert(s);
205         assert(fd >= 0);
206         assert(c);
207
208         r = connection_shovel(c,
209                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
210                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
211                               &c->server_event_source, &c->client_event_source);
212         if (r < 0)
213                 goto quit;
214
215         r = connection_shovel(c,
216                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
217                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
218                               &c->client_event_source, &c->server_event_source);
219         if (r < 0)
220                 goto quit;
221
222         /* EOF on both sides? */
223         if (c->server_fd == -1 && c->client_fd == -1)
224                 goto quit;
225
226         /* Server closed, and all data written to client? */
227         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
228                 goto quit;
229
230         /* Client closed, and all data written to server? */
231         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
232                 goto quit;
233
234         r = connection_enable_event_sources(c);
235         if (r < 0)
236                 goto quit;
237
238         return 1;
239
240 quit:
241         connection_free(c);
242         return 0; /* ignore errors, continue serving */
243 }
244
245 static int connection_enable_event_sources(Connection *c) {
246         uint32_t a = 0, b = 0;
247         int r;
248
249         assert(c);
250
251         if (c->server_to_client_buffer_full > 0)
252                 b |= EPOLLOUT;
253         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
254                 a |= EPOLLIN;
255
256         if (c->client_to_server_buffer_full > 0)
257                 a |= EPOLLOUT;
258         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
259                 b |= EPOLLIN;
260
261         if (c->server_event_source)
262                 r = sd_event_source_set_io_events(c->server_event_source, a);
263         else if (c->server_fd >= 0)
264                 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
265         else
266                 r = 0;
267
268         if (r < 0)
269                 return log_error_errno(r, "Failed to set up server event source: %m");
270
271         if (c->client_event_source)
272                 r = sd_event_source_set_io_events(c->client_event_source, b);
273         else if (c->client_fd >= 0)
274                 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
275         else
276                 r = 0;
277
278         if (r < 0)
279                 return log_error_errno(r, "Failed to set up client event source: %m");
280
281         return 0;
282 }
283
284 static int connection_complete(Connection *c) {
285         int r;
286
287         assert(c);
288
289         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
290         if (r < 0)
291                 goto fail;
292
293         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
294         if (r < 0)
295                 goto fail;
296
297         r = connection_enable_event_sources(c);
298         if (r < 0)
299                 goto fail;
300
301         return 0;
302
303 fail:
304         connection_free(c);
305         return 0; /* ignore errors, continue serving */
306 }
307
308 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
309         Connection *c = userdata;
310         socklen_t solen;
311         int error, r;
312
313         assert(s);
314         assert(fd >= 0);
315         assert(c);
316
317         solen = sizeof(error);
318         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
319         if (r < 0) {
320                 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
321                 goto fail;
322         }
323
324         if (error != 0) {
325                 log_error_errno(error, "Failed to connect to remote host: %m");
326                 goto fail;
327         }
328
329         c->client_event_source = sd_event_source_unref(c->client_event_source);
330
331         return connection_complete(c);
332
333 fail:
334         connection_free(c);
335         return 0; /* ignore errors, continue serving */
336 }
337
338 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
339         int r;
340
341         assert(c);
342         assert(sa);
343         assert(salen);
344
345         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
346         if (c->client_fd < 0) {
347                 log_error_errno(errno, "Failed to get remote socket: %m");
348                 goto fail;
349         }
350
351         r = connect(c->client_fd, sa, salen);
352         if (r < 0) {
353                 if (errno == EINPROGRESS) {
354                         r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
355                         if (r < 0) {
356                                 log_error_errno(r, "Failed to add connection socket: %m");
357                                 goto fail;
358                         }
359
360                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
361                         if (r < 0) {
362                                 log_error_errno(r, "Failed to enable oneshot event source: %m");
363                                 goto fail;
364                         }
365                 } else {
366                         log_error_errno(errno, "Failed to connect to remote host: %m");
367                         goto fail;
368                 }
369         } else {
370                 r = connection_complete(c);
371                 if (r < 0)
372                         goto fail;
373         }
374
375         return 0;
376
377 fail:
378         connection_free(c);
379         return 0; /* ignore errors, continue serving */
380 }
381
382 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
383         Connection *c = userdata;
384
385         assert(q);
386         assert(c);
387
388         if (ret != 0) {
389                 log_error("Failed to resolve host: %s", gai_strerror(ret));
390                 goto fail;
391         }
392
393         c->resolve_query = sd_resolve_query_unref(c->resolve_query);
394
395         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
396
397 fail:
398         connection_free(c);
399         return 0; /* ignore errors, continue serving */
400 }
401
402 static int resolve_remote(Connection *c) {
403
404         static const struct addrinfo hints = {
405                 .ai_family = AF_UNSPEC,
406                 .ai_socktype = SOCK_STREAM,
407                 .ai_flags = AI_ADDRCONFIG
408         };
409
410         union sockaddr_union sa = {};
411         const char *node, *service;
412         socklen_t salen;
413         int r;
414
415         if (path_is_absolute(arg_remote_host)) {
416                 sa.un.sun_family = AF_UNIX;
417                 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
418                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
419
420                 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
421
422                 return connection_start(c, &sa.sa, salen);
423         }
424
425         if (arg_remote_host[0] == '@') {
426                 sa.un.sun_family = AF_UNIX;
427                 sa.un.sun_path[0] = 0;
428                 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
429                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
430
431                 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
432
433                 return connection_start(c, &sa.sa, salen);
434         }
435
436         service = strrchr(arg_remote_host, ':');
437         if (service) {
438                 node = strndupa(arg_remote_host, service - arg_remote_host);
439                 service ++;
440         } else {
441                 node = arg_remote_host;
442                 service = "80";
443         }
444
445         log_debug("Looking up address info for %s:%s", node, service);
446         r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
447         if (r < 0) {
448                 log_error_errno(r, "Failed to resolve remote host: %m");
449                 goto fail;
450         }
451
452         return 0;
453
454 fail:
455         connection_free(c);
456         return 0; /* ignore errors, continue serving */
457 }
458
459 static int add_connection_socket(Context *context, int fd) {
460         Connection *c;
461         int r;
462
463         assert(context);
464         assert(fd >= 0);
465
466         if (set_size(context->connections) > CONNECTIONS_MAX) {
467                 log_warning("Hit connection limit, refusing connection.");
468                 safe_close(fd);
469                 return 0;
470         }
471
472         r = set_ensure_allocated(&context->connections, NULL);
473         if (r < 0) {
474                 log_oom();
475                 return 0;
476         }
477
478         c = new0(Connection, 1);
479         if (!c) {
480                 log_oom();
481                 return 0;
482         }
483
484         c->context = context;
485         c->server_fd = fd;
486         c->client_fd = -1;
487         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
488         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
489
490         r = set_put(context->connections, c);
491         if (r < 0) {
492                 free(c);
493                 log_oom();
494                 return 0;
495         }
496
497         return resolve_remote(c);
498 }
499
500 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
501         _cleanup_free_ char *peer = NULL;
502         Context *context = userdata;
503         int nfd = -1, r;
504
505         assert(s);
506         assert(fd >= 0);
507         assert(revents & EPOLLIN);
508         assert(context);
509
510         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
511         if (nfd < 0) {
512                 if (errno != -EAGAIN)
513                         log_warning_errno(errno, "Failed to accept() socket: %m");
514         } else {
515                 getpeername_pretty(nfd, &peer);
516                 log_debug("New connection from %s", strna(peer));
517
518                 r = add_connection_socket(context, nfd);
519                 if (r < 0) {
520                         log_error_errno(r, "Failed to accept connection, ignoring: %m");
521                         safe_close(fd);
522                 }
523         }
524
525         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
526         if (r < 0) {
527                 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
528                 sd_event_exit(context->event, r);
529                 return r;
530         }
531
532         return 1;
533 }
534
535 static int add_listen_socket(Context *context, int fd) {
536         sd_event_source *source;
537         int r;
538
539         assert(context);
540         assert(fd >= 0);
541
542         r = set_ensure_allocated(&context->listen, NULL);
543         if (r < 0) {
544                 log_oom();
545                 return r;
546         }
547
548         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
549         if (r < 0)
550                 return log_error_errno(r, "Failed to determine socket type: %m");
551         if (r == 0) {
552                 log_error("Passed in socket is not a stream socket.");
553                 return -EINVAL;
554         }
555
556         r = fd_nonblock(fd, true);
557         if (r < 0)
558                 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
559
560         r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
561         if (r < 0)
562                 return log_error_errno(r, "Failed to add event source: %m");
563
564         r = set_put(context->listen, source);
565         if (r < 0) {
566                 log_error_errno(r, "Failed to add source to set: %m");
567                 sd_event_source_unref(source);
568                 return r;
569         }
570
571         /* Set the watcher to oneshot in case other processes are also
572          * watching to accept(). */
573         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
574         if (r < 0)
575                 return log_error_errno(r, "Failed to enable oneshot mode: %m");
576
577         return 0;
578 }
579
580 static void help(void) {
581         printf("%1$s [HOST:PORT]\n"
582                "%1$s [SOCKET]\n\n"
583                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
584                "  -h --help              Show this help\n"
585                "     --version           Show package version\n",
586                program_invocation_short_name);
587 }
588
589 static int parse_argv(int argc, char *argv[]) {
590
591         enum {
592                 ARG_VERSION = 0x100,
593                 ARG_IGNORE_ENV
594         };
595
596         static const struct option options[] = {
597                 { "help",       no_argument, NULL, 'h'           },
598                 { "version",    no_argument, NULL, ARG_VERSION   },
599                 {}
600         };
601
602         int c;
603
604         assert(argc >= 0);
605         assert(argv);
606
607         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
608
609                 switch (c) {
610
611                 case 'h':
612                         help();
613                         return 0;
614
615                 case ARG_VERSION:
616                         puts(PACKAGE_STRING);
617                         puts(SYSTEMD_FEATURES);
618                         return 0;
619
620                 case '?':
621                         return -EINVAL;
622
623                 default:
624                         assert_not_reached("Unhandled option");
625                 }
626
627         if (optind >= argc) {
628                 log_error("Not enough parameters.");
629                 return -EINVAL;
630         }
631
632         if (argc != optind+1) {
633                 log_error("Too many parameters.");
634                 return -EINVAL;
635         }
636
637         arg_remote_host = argv[optind];
638         return 1;
639 }
640
641 int main(int argc, char *argv[]) {
642         Context context = {};
643         int r, n, fd;
644
645         log_parse_environment();
646         log_open();
647
648         r = parse_argv(argc, argv);
649         if (r <= 0)
650                 goto finish;
651
652         r = sd_event_default(&context.event);
653         if (r < 0) {
654                 log_error_errno(r, "Failed to allocate event loop: %m");
655                 goto finish;
656         }
657
658         r = sd_resolve_default(&context.resolve);
659         if (r < 0) {
660                 log_error_errno(r, "Failed to allocate resolver: %m");
661                 goto finish;
662         }
663
664         r = sd_resolve_attach_event(context.resolve, context.event, 0);
665         if (r < 0) {
666                 log_error_errno(r, "Failed to attach resolver: %m");
667                 goto finish;
668         }
669
670         sd_event_set_watchdog(context.event, true);
671
672         n = sd_listen_fds(1);
673         if (n < 0) {
674                 log_error("Failed to receive sockets from parent.");
675                 r = n;
676                 goto finish;
677         } else if (n == 0) {
678                 log_error("Didn't get any sockets passed in.");
679                 r = -EINVAL;
680                 goto finish;
681         }
682
683         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
684                 r = add_listen_socket(&context, fd);
685                 if (r < 0)
686                         goto finish;
687         }
688
689         r = sd_event_loop(context.event);
690         if (r < 0) {
691                 log_error_errno(r, "Failed to run event loop: %m");
692                 goto finish;
693         }
694
695 finish:
696         context_free(&context);
697
698         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
699 }