chiark / gitweb /
remove unused includes
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <errno.h>
23 #include <getopt.h>
24 #include <stdio.h>
25 #include <stdlib.h>
26 #include <string.h>
27 #include <netdb.h>
28 #include <fcntl.h>
29 #include <sys/socket.h>
30 #include <sys/un.h>
31 #include <unistd.h>
32
33 #include "sd-daemon.h"
34 #include "sd-event.h"
35 #include "sd-resolve.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "build.h"
40 #include "set.h"
41 #include "path-util.h"
42
43 #define BUFFER_SIZE (256 * 1024)
44 #define CONNECTIONS_MAX 256
45
46 static const char *arg_remote_host = NULL;
47
48 typedef struct Context {
49         sd_event *event;
50         sd_resolve *resolve;
51
52         Set *listen;
53         Set *connections;
54 } Context;
55
56 typedef struct Connection {
57         Context *context;
58
59         int server_fd, client_fd;
60         int server_to_client_buffer[2]; /* a pipe */
61         int client_to_server_buffer[2]; /* a pipe */
62
63         size_t server_to_client_buffer_full, client_to_server_buffer_full;
64         size_t server_to_client_buffer_size, client_to_server_buffer_size;
65
66         sd_event_source *server_event_source, *client_event_source;
67
68         sd_resolve_query *resolve_query;
69 } Connection;
70
71 static void connection_free(Connection *c) {
72         assert(c);
73
74         if (c->context)
75                 set_remove(c->context->connections, c);
76
77         sd_event_source_unref(c->server_event_source);
78         sd_event_source_unref(c->client_event_source);
79
80         safe_close(c->server_fd);
81         safe_close(c->client_fd);
82
83         safe_close_pair(c->server_to_client_buffer);
84         safe_close_pair(c->client_to_server_buffer);
85
86         sd_resolve_query_unref(c->resolve_query);
87
88         free(c);
89 }
90
91 static void context_free(Context *context) {
92         sd_event_source *es;
93         Connection *c;
94
95         assert(context);
96
97         while ((es = set_steal_first(context->listen)))
98                 sd_event_source_unref(es);
99
100         while ((c = set_first(context->connections)))
101                 connection_free(c);
102
103         set_free(context->listen);
104         set_free(context->connections);
105
106         sd_event_unref(context->event);
107         sd_resolve_unref(context->resolve);
108 }
109
110 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
111         int r;
112
113         assert(c);
114         assert(buffer);
115         assert(sz);
116
117         if (buffer[0] >= 0)
118                 return 0;
119
120         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
121         if (r < 0)
122                 return log_error_errno(errno, "Failed to allocate pipe buffer: %m");
123
124         (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
125
126         r = fcntl(buffer[0], F_GETPIPE_SZ);
127         if (r < 0)
128                 return log_error_errno(errno, "Failed to get pipe buffer size: %m");
129
130         assert(r > 0);
131         *sz = r;
132
133         return 0;
134 }
135
136 static int connection_shovel(
137                 Connection *c,
138                 int *from, int buffer[2], int *to,
139                 size_t *full, size_t *sz,
140                 sd_event_source **from_source, sd_event_source **to_source) {
141
142         bool shoveled;
143
144         assert(c);
145         assert(from);
146         assert(buffer);
147         assert(buffer[0] >= 0);
148         assert(buffer[1] >= 0);
149         assert(to);
150         assert(full);
151         assert(sz);
152         assert(from_source);
153         assert(to_source);
154
155         do {
156                 ssize_t z;
157
158                 shoveled = false;
159
160                 if (*full < *sz && *from >= 0 && *to >= 0) {
161                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
162                         if (z > 0) {
163                                 *full += z;
164                                 shoveled = true;
165                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
166                                 *from_source = sd_event_source_unref(*from_source);
167                                 *from = safe_close(*from);
168                         } else if (errno != EAGAIN && errno != EINTR)
169                                 return log_error_errno(errno, "Failed to splice: %m");
170                 }
171
172                 if (*full > 0 && *to >= 0) {
173                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
174                         if (z > 0) {
175                                 *full -= z;
176                                 shoveled = true;
177                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
178                                 *to_source = sd_event_source_unref(*to_source);
179                                 *to = safe_close(*to);
180                         } else if (errno != EAGAIN && errno != EINTR)
181                                 return log_error_errno(errno, "Failed to splice: %m");
182                 }
183         } while (shoveled);
184
185         return 0;
186 }
187
188 static int connection_enable_event_sources(Connection *c);
189
190 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
191         Connection *c = userdata;
192         int r;
193
194         assert(s);
195         assert(fd >= 0);
196         assert(c);
197
198         r = connection_shovel(c,
199                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
200                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
201                               &c->server_event_source, &c->client_event_source);
202         if (r < 0)
203                 goto quit;
204
205         r = connection_shovel(c,
206                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
207                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
208                               &c->client_event_source, &c->server_event_source);
209         if (r < 0)
210                 goto quit;
211
212         /* EOF on both sides? */
213         if (c->server_fd == -1 && c->client_fd == -1)
214                 goto quit;
215
216         /* Server closed, and all data written to client? */
217         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
218                 goto quit;
219
220         /* Client closed, and all data written to server? */
221         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
222                 goto quit;
223
224         r = connection_enable_event_sources(c);
225         if (r < 0)
226                 goto quit;
227
228         return 1;
229
230 quit:
231         connection_free(c);
232         return 0; /* ignore errors, continue serving */
233 }
234
235 static int connection_enable_event_sources(Connection *c) {
236         uint32_t a = 0, b = 0;
237         int r;
238
239         assert(c);
240
241         if (c->server_to_client_buffer_full > 0)
242                 b |= EPOLLOUT;
243         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
244                 a |= EPOLLIN;
245
246         if (c->client_to_server_buffer_full > 0)
247                 a |= EPOLLOUT;
248         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
249                 b |= EPOLLIN;
250
251         if (c->server_event_source)
252                 r = sd_event_source_set_io_events(c->server_event_source, a);
253         else if (c->server_fd >= 0)
254                 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
255         else
256                 r = 0;
257
258         if (r < 0)
259                 return log_error_errno(r, "Failed to set up server event source: %m");
260
261         if (c->client_event_source)
262                 r = sd_event_source_set_io_events(c->client_event_source, b);
263         else if (c->client_fd >= 0)
264                 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
265         else
266                 r = 0;
267
268         if (r < 0)
269                 return log_error_errno(r, "Failed to set up client event source: %m");
270
271         return 0;
272 }
273
274 static int connection_complete(Connection *c) {
275         int r;
276
277         assert(c);
278
279         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
280         if (r < 0)
281                 goto fail;
282
283         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
284         if (r < 0)
285                 goto fail;
286
287         r = connection_enable_event_sources(c);
288         if (r < 0)
289                 goto fail;
290
291         return 0;
292
293 fail:
294         connection_free(c);
295         return 0; /* ignore errors, continue serving */
296 }
297
298 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
299         Connection *c = userdata;
300         socklen_t solen;
301         int error, r;
302
303         assert(s);
304         assert(fd >= 0);
305         assert(c);
306
307         solen = sizeof(error);
308         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
309         if (r < 0) {
310                 log_error_errno(errno, "Failed to issue SO_ERROR: %m");
311                 goto fail;
312         }
313
314         if (error != 0) {
315                 log_error_errno(error, "Failed to connect to remote host: %m");
316                 goto fail;
317         }
318
319         c->client_event_source = sd_event_source_unref(c->client_event_source);
320
321         return connection_complete(c);
322
323 fail:
324         connection_free(c);
325         return 0; /* ignore errors, continue serving */
326 }
327
328 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
329         int r;
330
331         assert(c);
332         assert(sa);
333         assert(salen);
334
335         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
336         if (c->client_fd < 0) {
337                 log_error_errno(errno, "Failed to get remote socket: %m");
338                 goto fail;
339         }
340
341         r = connect(c->client_fd, sa, salen);
342         if (r < 0) {
343                 if (errno == EINPROGRESS) {
344                         r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
345                         if (r < 0) {
346                                 log_error_errno(r, "Failed to add connection socket: %m");
347                                 goto fail;
348                         }
349
350                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
351                         if (r < 0) {
352                                 log_error_errno(r, "Failed to enable oneshot event source: %m");
353                                 goto fail;
354                         }
355                 } else {
356                         log_error_errno(errno, "Failed to connect to remote host: %m");
357                         goto fail;
358                 }
359         } else {
360                 r = connection_complete(c);
361                 if (r < 0)
362                         goto fail;
363         }
364
365         return 0;
366
367 fail:
368         connection_free(c);
369         return 0; /* ignore errors, continue serving */
370 }
371
372 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
373         Connection *c = userdata;
374
375         assert(q);
376         assert(c);
377
378         if (ret != 0) {
379                 log_error("Failed to resolve host: %s", gai_strerror(ret));
380                 goto fail;
381         }
382
383         c->resolve_query = sd_resolve_query_unref(c->resolve_query);
384
385         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
386
387 fail:
388         connection_free(c);
389         return 0; /* ignore errors, continue serving */
390 }
391
392 static int resolve_remote(Connection *c) {
393
394         static const struct addrinfo hints = {
395                 .ai_family = AF_UNSPEC,
396                 .ai_socktype = SOCK_STREAM,
397                 .ai_flags = AI_ADDRCONFIG
398         };
399
400         union sockaddr_union sa = {};
401         const char *node, *service;
402         socklen_t salen;
403         int r;
404
405         if (path_is_absolute(arg_remote_host)) {
406                 sa.un.sun_family = AF_UNIX;
407                 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
408                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
409
410                 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
411
412                 return connection_start(c, &sa.sa, salen);
413         }
414
415         if (arg_remote_host[0] == '@') {
416                 sa.un.sun_family = AF_UNIX;
417                 sa.un.sun_path[0] = 0;
418                 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
419                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
420
421                 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
422
423                 return connection_start(c, &sa.sa, salen);
424         }
425
426         service = strrchr(arg_remote_host, ':');
427         if (service) {
428                 node = strndupa(arg_remote_host, service - arg_remote_host);
429                 service ++;
430         } else {
431                 node = arg_remote_host;
432                 service = "80";
433         }
434
435         log_debug("Looking up address info for %s:%s", node, service);
436         r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
437         if (r < 0) {
438                 log_error_errno(r, "Failed to resolve remote host: %m");
439                 goto fail;
440         }
441
442         return 0;
443
444 fail:
445         connection_free(c);
446         return 0; /* ignore errors, continue serving */
447 }
448
449 static int add_connection_socket(Context *context, int fd) {
450         Connection *c;
451         int r;
452
453         assert(context);
454         assert(fd >= 0);
455
456         if (set_size(context->connections) > CONNECTIONS_MAX) {
457                 log_warning("Hit connection limit, refusing connection.");
458                 safe_close(fd);
459                 return 0;
460         }
461
462         r = set_ensure_allocated(&context->connections, NULL);
463         if (r < 0) {
464                 log_oom();
465                 return 0;
466         }
467
468         c = new0(Connection, 1);
469         if (!c) {
470                 log_oom();
471                 return 0;
472         }
473
474         c->context = context;
475         c->server_fd = fd;
476         c->client_fd = -1;
477         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
478         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
479
480         r = set_put(context->connections, c);
481         if (r < 0) {
482                 free(c);
483                 log_oom();
484                 return 0;
485         }
486
487         return resolve_remote(c);
488 }
489
490 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
491         _cleanup_free_ char *peer = NULL;
492         Context *context = userdata;
493         int nfd = -1, r;
494
495         assert(s);
496         assert(fd >= 0);
497         assert(revents & EPOLLIN);
498         assert(context);
499
500         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
501         if (nfd < 0) {
502                 if (errno != -EAGAIN)
503                         log_warning_errno(errno, "Failed to accept() socket: %m");
504         } else {
505                 getpeername_pretty(nfd, &peer);
506                 log_debug("New connection from %s", strna(peer));
507
508                 r = add_connection_socket(context, nfd);
509                 if (r < 0) {
510                         log_error_errno(r, "Failed to accept connection, ignoring: %m");
511                         safe_close(fd);
512                 }
513         }
514
515         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
516         if (r < 0) {
517                 log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m");
518                 sd_event_exit(context->event, r);
519                 return r;
520         }
521
522         return 1;
523 }
524
525 static int add_listen_socket(Context *context, int fd) {
526         sd_event_source *source;
527         int r;
528
529         assert(context);
530         assert(fd >= 0);
531
532         r = set_ensure_allocated(&context->listen, NULL);
533         if (r < 0) {
534                 log_oom();
535                 return r;
536         }
537
538         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
539         if (r < 0)
540                 return log_error_errno(r, "Failed to determine socket type: %m");
541         if (r == 0) {
542                 log_error("Passed in socket is not a stream socket.");
543                 return -EINVAL;
544         }
545
546         r = fd_nonblock(fd, true);
547         if (r < 0)
548                 return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m");
549
550         r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
551         if (r < 0)
552                 return log_error_errno(r, "Failed to add event source: %m");
553
554         r = set_put(context->listen, source);
555         if (r < 0) {
556                 log_error_errno(r, "Failed to add source to set: %m");
557                 sd_event_source_unref(source);
558                 return r;
559         }
560
561         /* Set the watcher to oneshot in case other processes are also
562          * watching to accept(). */
563         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
564         if (r < 0)
565                 return log_error_errno(r, "Failed to enable oneshot mode: %m");
566
567         return 0;
568 }
569
570 static void help(void) {
571         printf("%1$s [HOST:PORT]\n"
572                "%1$s [SOCKET]\n\n"
573                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
574                "  -h --help              Show this help\n"
575                "     --version           Show package version\n",
576                program_invocation_short_name);
577 }
578
579 static int parse_argv(int argc, char *argv[]) {
580
581         enum {
582                 ARG_VERSION = 0x100,
583                 ARG_IGNORE_ENV
584         };
585
586         static const struct option options[] = {
587                 { "help",       no_argument, NULL, 'h'           },
588                 { "version",    no_argument, NULL, ARG_VERSION   },
589                 {}
590         };
591
592         int c;
593
594         assert(argc >= 0);
595         assert(argv);
596
597         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
598
599                 switch (c) {
600
601                 case 'h':
602                         help();
603                         return 0;
604
605                 case ARG_VERSION:
606                         puts(PACKAGE_STRING);
607                         puts(SYSTEMD_FEATURES);
608                         return 0;
609
610                 case '?':
611                         return -EINVAL;
612
613                 default:
614                         assert_not_reached("Unhandled option");
615                 }
616
617         if (optind >= argc) {
618                 log_error("Not enough parameters.");
619                 return -EINVAL;
620         }
621
622         if (argc != optind+1) {
623                 log_error("Too many parameters.");
624                 return -EINVAL;
625         }
626
627         arg_remote_host = argv[optind];
628         return 1;
629 }
630
631 int main(int argc, char *argv[]) {
632         Context context = {};
633         int r, n, fd;
634
635         log_parse_environment();
636         log_open();
637
638         r = parse_argv(argc, argv);
639         if (r <= 0)
640                 goto finish;
641
642         r = sd_event_default(&context.event);
643         if (r < 0) {
644                 log_error_errno(r, "Failed to allocate event loop: %m");
645                 goto finish;
646         }
647
648         r = sd_resolve_default(&context.resolve);
649         if (r < 0) {
650                 log_error_errno(r, "Failed to allocate resolver: %m");
651                 goto finish;
652         }
653
654         r = sd_resolve_attach_event(context.resolve, context.event, 0);
655         if (r < 0) {
656                 log_error_errno(r, "Failed to attach resolver: %m");
657                 goto finish;
658         }
659
660         sd_event_set_watchdog(context.event, true);
661
662         n = sd_listen_fds(1);
663         if (n < 0) {
664                 log_error("Failed to receive sockets from parent.");
665                 r = n;
666                 goto finish;
667         } else if (n == 0) {
668                 log_error("Didn't get any sockets passed in.");
669                 r = -EINVAL;
670                 goto finish;
671         }
672
673         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
674                 r = add_listen_socket(&context, fd);
675                 if (r < 0)
676                         goto finish;
677         }
678
679         r = sd_event_loop(context.event);
680         if (r < 0) {
681                 log_error_errno(r, "Failed to run event loop: %m");
682                 goto finish;
683         }
684
685 finish:
686         context_free(&context);
687
688         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
689 }