chiark / gitweb /
b6a7f1c1ba688ebba63fbd47cc0421f1376ad85f
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41 #include "set.h"
42 #include "path-util.h"
43
44 #define BUFFER_SIZE (256 * 1024)
45 #define CONNECTIONS_MAX 256
46
47 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
48 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
49
50 typedef struct Context {
51         Set *listen;
52         Set *connections;
53 } Context;
54
55 typedef struct Connection {
56         int server_fd, client_fd;
57         int server_to_client_buffer[2]; /* a pipe */
58         int client_to_server_buffer[2]; /* a pipe */
59
60         size_t server_to_client_buffer_full, client_to_server_buffer_full;
61         size_t server_to_client_buffer_size, client_to_server_buffer_size;
62
63         sd_event_source *server_event_source, *client_event_source;
64 } Connection;
65
66 static const char *arg_remote_host = NULL;
67
68 static void connection_free(Connection *c) {
69         assert(c);
70
71         sd_event_source_unref(c->server_event_source);
72         sd_event_source_unref(c->client_event_source);
73
74         if (c->server_fd >= 0)
75                 close_nointr_nofail(c->server_fd);
76         if (c->client_fd >= 0)
77                 close_nointr_nofail(c->client_fd);
78
79         close_pipe(c->server_to_client_buffer);
80         close_pipe(c->client_to_server_buffer);
81
82         free(c);
83 }
84
85 static void context_free(Context *context) {
86         sd_event_source *es;
87         Connection *c;
88
89         assert(context);
90
91         while ((es = set_steal_first(context->listen)))
92                 sd_event_source_unref(es);
93
94         while ((c = set_steal_first(context->connections)))
95                 connection_free(c);
96
97         set_free(context->listen);
98         set_free(context->connections);
99 }
100
101 static int get_remote_sockaddr(union sockaddr_union *sa, socklen_t *salen) {
102         int r;
103
104         assert(sa);
105         assert(salen);
106
107         if (path_is_absolute(arg_remote_host)) {
108                 sa->un.sun_family = AF_UNIX;
109                 strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1);
110                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
111
112                 *salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa->un.sun_path);
113
114         } else if (arg_remote_host[0] == '@') {
115                 sa->un.sun_family = AF_UNIX;
116                 sa->un.sun_path[0] = 0;
117                 strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2);
118                 sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0;
119
120                 *salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa->un.sun_path + 1);
121
122         } else {
123                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
124                 const char *node, *service;
125
126                 struct addrinfo hints = {
127                         .ai_family = AF_UNSPEC,
128                         .ai_socktype = SOCK_STREAM,
129                         .ai_flags = AI_ADDRCONFIG
130                 };
131
132                 service = strrchr(arg_remote_host, ':');
133                 if (service) {
134                         node = strndupa(arg_remote_host, service - arg_remote_host);
135                         service ++;
136                 } else {
137                         node = arg_remote_host;
138                         service = "80";
139                 }
140
141                 log_debug("Looking up address info for %s:%s", node, service);
142                 r = getaddrinfo(node, service, &hints, &result);
143                 if (r != 0) {
144                         log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r));
145                         return -EHOSTUNREACH;
146                 }
147
148                 assert(result);
149                 if (result->ai_addrlen > sizeof(union sockaddr_union)) {
150                         log_error("Address too long.");
151                         return -E2BIG;
152                 }
153
154                 memcpy(sa, result->ai_addr, result->ai_addrlen);
155                 *salen = result->ai_addrlen;
156         }
157
158         return 0;
159 }
160
161 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
162         int r;
163
164         assert(c);
165         assert(buffer);
166         assert(sz);
167
168         if (buffer[0] >= 0)
169                 return 0;
170
171         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
172         if (r < 0) {
173                 log_error("Failed to allocate pipe buffer: %m");
174                 return -errno;
175         }
176
177         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
178
179         r = fcntl(buffer[0], F_GETPIPE_SZ);
180         if (r < 0) {
181                 log_error("Failed to get pipe buffer size: %m");
182                 return -errno;
183         }
184
185         assert(r > 0);
186         *sz = r;
187
188         return 0;
189 }
190
191 static int connection_shovel(
192                 Connection *c,
193                 int *from, int buffer[2], int *to,
194                 size_t *full, size_t *sz,
195                 sd_event_source **from_source, sd_event_source **to_source) {
196
197         bool shoveled;
198
199         assert(c);
200         assert(from);
201         assert(buffer);
202         assert(buffer[0] >= 0);
203         assert(buffer[1] >= 0);
204         assert(to);
205         assert(full);
206         assert(sz);
207         assert(from_source);
208         assert(to_source);
209
210         do {
211                 ssize_t z;
212
213                 shoveled = false;
214
215                 if (*full < *sz && *from >= 0 && *to >= 0) {
216                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
217                         if (z > 0) {
218                                 *full += z;
219                                 shoveled = true;
220                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
221                                 *from_source = sd_event_source_unref(*from_source);
222                                 close_nointr_nofail(*from);
223                                 *from = -1;
224                         } else if (errno != EAGAIN && errno != EINTR) {
225                                 log_error("Failed to splice: %m");
226                                 return -errno;
227                         }
228                 }
229
230                 if (*full > 0 && *to >= 0) {
231                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
232                         if (z > 0) {
233                                 *full -= z;
234                                 shoveled = true;
235                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
236                                 *to_source = sd_event_source_unref(*to_source);
237                                 close_nointr_nofail(*to);
238                                 *to = -1;
239                         } else if (errno != EAGAIN && errno != EINTR) {
240                                 log_error("Failed to splice: %m");
241                                 return -errno;
242                         }
243                 }
244         } while (shoveled);
245
246         return 0;
247 }
248
249 static int connection_enable_event_sources(Connection *c, sd_event *event);
250
251 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
252         Connection *c = userdata;
253         int r;
254
255         assert(s);
256         assert(fd >= 0);
257         assert(c);
258
259         r = connection_shovel(c,
260                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
261                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
262                               &c->server_event_source, &c->client_event_source);
263         if (r < 0)
264                 goto quit;
265
266         r = connection_shovel(c,
267                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
268                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
269                               &c->client_event_source, &c->server_event_source);
270         if (r < 0)
271                 goto quit;
272
273         /* EOF on both sides? */
274         if (c->server_fd == -1 && c->client_fd == -1)
275                 goto quit;
276
277         /* Server closed, and all data written to client? */
278         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
279                 goto quit;
280
281         /* Client closed, and all data written to server? */
282         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
283                 goto quit;
284
285         r = connection_enable_event_sources(c, sd_event_get(s));
286         if (r < 0)
287                 goto quit;
288
289         return 1;
290
291 quit:
292         connection_free(c);
293         return 0; /* ignore errors, continue serving */
294 }
295
296 static int connection_enable_event_sources(Connection *c, sd_event *event) {
297         uint32_t a = 0, b = 0;
298         int r;
299
300         assert(c);
301         assert(event);
302
303         if (c->server_to_client_buffer_full > 0)
304                 b |= EPOLLOUT;
305         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
306                 a |= EPOLLIN;
307
308         if (c->client_to_server_buffer_full > 0)
309                 a |= EPOLLOUT;
310         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
311                 b |= EPOLLIN;
312
313         if (c->server_event_source)
314                 r = sd_event_source_set_io_events(c->server_event_source, a);
315         else if (c->server_fd >= 0)
316                 r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source);
317         else
318                 r = 0;
319
320         if (r < 0) {
321                 log_error("Failed to set up server event source: %s", strerror(-r));
322                 return r;
323         }
324
325         if (c->client_event_source)
326                 r = sd_event_source_set_io_events(c->client_event_source, b);
327         else if (c->client_fd >= 0)
328                 r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source);
329         else
330                 r = 0;
331
332         if (r < 0) {
333                 log_error("Failed to set up client event source: %s", strerror(-r));
334                 return r;
335         }
336
337         return 0;
338 }
339
340 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
341         Connection *c = userdata;
342         socklen_t solen;
343         int error, r;
344
345         assert(s);
346         assert(fd >= 0);
347         assert(c);
348
349         solen = sizeof(error);
350         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
351         if (r < 0) {
352                 log_error("Failed to issue SO_ERROR: %m");
353                 goto fail;
354         }
355
356         if (error != 0) {
357                 log_error("Failed to connect to remote host: %s", strerror(error));
358                 goto fail;
359         }
360
361         c->client_event_source = sd_event_source_unref(c->client_event_source);
362
363         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
364         if (r < 0)
365                 goto fail;
366
367         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
368         if (r < 0)
369                 goto fail;
370
371         r = connection_enable_event_sources(c, sd_event_get(s));
372         if (r < 0)
373                 goto fail;
374
375         return 0;
376
377 fail:
378         connection_free(c);
379         return 0; /* ignore errors, continue serving */
380 }
381
382 static int add_connection_socket(Context *context, sd_event *event, int fd) {
383         union sockaddr_union sa = {};
384         socklen_t salen;
385         Connection *c;
386         int r;
387
388         assert(context);
389         assert(event);
390         assert(fd >= 0);
391
392         if (set_size(context->connections) > CONNECTIONS_MAX) {
393                 log_warning("Hit connection limit, refusing connection.");
394                 close_nointr_nofail(fd);
395                 return 0;
396         }
397
398         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
399         if (r < 0)
400                 return log_oom();
401
402         c = new0(Connection, 1);
403         if (!c)
404                 return log_oom();
405
406         c->server_fd = fd;
407         c->client_fd = -1;
408         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
409         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
410
411         r = get_remote_sockaddr(&sa, &salen);
412         if (r < 0)
413                 goto fail;
414
415         c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
416         if (c->client_fd < 0) {
417                 log_error("Failed to get remote socket: %m");
418                 goto fail;
419         }
420
421         r = connect(c->client_fd, &sa.sa, salen);
422         if (r < 0) {
423                 if (errno == EINPROGRESS) {
424                         r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source);
425                         if (r < 0) {
426                                 log_error("Failed to add connection socket: %s", strerror(-r));
427                                 goto fail;
428                         }
429
430                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
431                         if (r < 0) {
432                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
433                                 goto fail;
434                         }
435                 } else {
436                         log_error("Failed to connect to remote host: %m");
437                         goto fail;
438                 }
439         } else {
440                 r = connection_enable_event_sources(c, event);
441                 if (r < 0)
442                         goto fail;
443         }
444
445         return 0;
446
447 fail:
448         connection_free(c);
449         return 0; /* ignore non-OOM errors, continue serving */
450 }
451
452 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
453         Context *context = userdata;
454         int nfd = -1, r;
455
456         assert(s);
457         assert(fd >= 0);
458         assert(revents & EPOLLIN);
459         assert(context);
460
461         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
462         if (nfd >= 0) {
463                 _cleanup_free_ char *peer = NULL;
464
465                 getpeername_pretty(nfd, &peer);
466                 log_debug("New connection from %s", strna(peer));
467
468                 r = add_connection_socket(context, sd_event_get(s), nfd);
469                 if (r < 0) {
470                         close_nointr_nofail(fd);
471                         return r;
472                 }
473
474         } else if (errno != -EAGAIN)
475                 log_warning("Failed to accept() socket: %m");
476
477         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
478         if (r < 0) {
479                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
480                 return r;
481         }
482
483         return 1;
484 }
485
486 static int add_listen_socket(Context *context, sd_event *event, int fd) {
487         sd_event_source *source;
488         int r;
489
490         assert(context);
491         assert(event);
492         assert(fd >= 0);
493
494         log_info("Listening on %i", fd);
495
496         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
497         if (r < 0) {
498                 log_oom();
499                 return r;
500         }
501
502         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
503         if (r < 0) {
504                 log_error("Failed to determine socket type: %s", strerror(-r));
505                 return r;
506         }
507         if (r == 0) {
508                 log_error("Passed in socket is not a stream socket.");
509                 return -EINVAL;
510         }
511
512         r = fd_nonblock(fd, true);
513         if (r < 0) {
514                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
515                 return r;
516         }
517
518         r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source);
519         if (r < 0) {
520                 log_error("Failed to add event source: %s", strerror(-r));
521                 return r;
522         }
523
524         r = set_put(context->listen, source);
525         if (r < 0) {
526                 log_error("Failed to add source to set: %s", strerror(-r));
527                 sd_event_source_unref(source);
528                 return r;
529         }
530
531         /* Set the watcher to oneshot in case other processes are also
532          * watching to accept(). */
533         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
534         if (r < 0) {
535                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
536                 return r;
537         }
538
539         return 0;
540 }
541
542 static int help(void) {
543
544         printf("%s [HOST:PORT]\n"
545                "%s [SOCKET]\n\n"
546                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
547                "  -h --help              Show this help\n"
548                "     --version           Show package version\n",
549                program_invocation_short_name,
550                program_invocation_short_name);
551
552         return 0;
553 }
554
555 static int parse_argv(int argc, char *argv[]) {
556
557         enum {
558                 ARG_VERSION = 0x100,
559                 ARG_IGNORE_ENV
560         };
561
562         static const struct option options[] = {
563                 { "help",       no_argument, NULL, 'h'           },
564                 { "version",    no_argument, NULL, ARG_VERSION   },
565                 {}
566         };
567
568         int c;
569
570         assert(argc >= 0);
571         assert(argv);
572
573         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
574
575                 switch (c) {
576
577                 case 'h':
578                         return help();
579
580                 case ARG_VERSION:
581                         puts(PACKAGE_STRING);
582                         puts(SYSTEMD_FEATURES);
583                         return 0;
584
585                 case '?':
586                         return -EINVAL;
587
588                 default:
589                         assert_not_reached("Unhandled option");
590                 }
591         }
592
593         if (optind >= argc) {
594                 log_error("Not enough parameters.");
595                 return -EINVAL;
596         }
597
598         if (argc != optind+1) {
599                 log_error("Too many parameters.");
600                 return -EINVAL;
601         }
602
603         arg_remote_host = argv[optind];
604         return 1;
605 }
606
607 int main(int argc, char *argv[]) {
608         _cleanup_event_unref_ sd_event *event = NULL;
609         Context context = {};
610         int r, n, fd;
611
612         log_parse_environment();
613         log_open();
614
615         r = parse_argv(argc, argv);
616         if (r <= 0)
617                 goto finish;
618
619         r = sd_event_new(&event);
620         if (r < 0) {
621                 log_error("Failed to allocate event loop: %s", strerror(-r));
622                 goto finish;
623         }
624
625         n = sd_listen_fds(1);
626         if (n < 0) {
627                 log_error("Failed to receive sockets from parent.");
628                 r = n;
629                 goto finish;
630         } else if (n == 0) {
631                 log_error("Didn't get any sockets passed in.");
632                 r = -EINVAL;
633                 goto finish;
634         }
635
636         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
637                 r = add_listen_socket(&context, event, fd);
638                 if (r < 0)
639                         goto finish;
640         }
641
642         r = sd_event_loop(event);
643         if (r < 0) {
644                 log_error("Failed to run event loop: %s", strerror(-r));
645                 goto finish;
646         }
647
648 finish:
649         context_free(&context);
650
651         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
652 }