chiark / gitweb /
machine: make sure unpriviliged "machinectl status" can show the machine's OS version
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "sd-resolve.h"
37 #include "log.h"
38 #include "socket-util.h"
39 #include "util.h"
40 #include "event-util.h"
41 #include "build.h"
42 #include "set.h"
43 #include "path-util.h"
44
45 #define BUFFER_SIZE (256 * 1024)
46 #define CONNECTIONS_MAX 256
47
48 static const char *arg_remote_host = NULL;
49
50 typedef struct Context {
51         sd_event *event;
52         sd_resolve *resolve;
53
54         Set *listen;
55         Set *connections;
56 } Context;
57
58 typedef struct Connection {
59         Context *context;
60
61         int server_fd, client_fd;
62         int server_to_client_buffer[2]; /* a pipe */
63         int client_to_server_buffer[2]; /* a pipe */
64
65         size_t server_to_client_buffer_full, client_to_server_buffer_full;
66         size_t server_to_client_buffer_size, client_to_server_buffer_size;
67
68         sd_event_source *server_event_source, *client_event_source;
69
70         sd_resolve_query *resolve_query;
71 } Connection;
72
73 static void connection_free(Connection *c) {
74         assert(c);
75
76         if (c->context)
77                 set_remove(c->context->connections, c);
78
79         sd_event_source_unref(c->server_event_source);
80         sd_event_source_unref(c->client_event_source);
81
82         safe_close(c->server_fd);
83         safe_close(c->client_fd);
84
85         safe_close_pair(c->server_to_client_buffer);
86         safe_close_pair(c->client_to_server_buffer);
87
88         sd_resolve_query_unref(c->resolve_query);
89
90         free(c);
91 }
92
93 static void context_free(Context *context) {
94         sd_event_source *es;
95         Connection *c;
96
97         assert(context);
98
99         while ((es = set_steal_first(context->listen)))
100                 sd_event_source_unref(es);
101
102         while ((c = set_first(context->connections)))
103                 connection_free(c);
104
105         set_free(context->listen);
106         set_free(context->connections);
107
108         sd_event_unref(context->event);
109         sd_resolve_unref(context->resolve);
110 }
111
112 static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) {
113         int r;
114
115         assert(c);
116         assert(buffer);
117         assert(sz);
118
119         if (buffer[0] >= 0)
120                 return 0;
121
122         r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK);
123         if (r < 0) {
124                 log_error("Failed to allocate pipe buffer: %m");
125                 return -errno;
126         }
127
128         fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE);
129
130         r = fcntl(buffer[0], F_GETPIPE_SZ);
131         if (r < 0) {
132                 log_error("Failed to get pipe buffer size: %m");
133                 return -errno;
134         }
135
136         assert(r > 0);
137         *sz = r;
138
139         return 0;
140 }
141
142 static int connection_shovel(
143                 Connection *c,
144                 int *from, int buffer[2], int *to,
145                 size_t *full, size_t *sz,
146                 sd_event_source **from_source, sd_event_source **to_source) {
147
148         bool shoveled;
149
150         assert(c);
151         assert(from);
152         assert(buffer);
153         assert(buffer[0] >= 0);
154         assert(buffer[1] >= 0);
155         assert(to);
156         assert(full);
157         assert(sz);
158         assert(from_source);
159         assert(to_source);
160
161         do {
162                 ssize_t z;
163
164                 shoveled = false;
165
166                 if (*full < *sz && *from >= 0 && *to >= 0) {
167                         z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
168                         if (z > 0) {
169                                 *full += z;
170                                 shoveled = true;
171                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
172                                 *from_source = sd_event_source_unref(*from_source);
173                                 *from = safe_close(*from);
174                         } else if (errno != EAGAIN && errno != EINTR) {
175                                 log_error("Failed to splice: %m");
176                                 return -errno;
177                         }
178                 }
179
180                 if (*full > 0 && *to >= 0) {
181                         z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK);
182                         if (z > 0) {
183                                 *full -= z;
184                                 shoveled = true;
185                         } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) {
186                                 *to_source = sd_event_source_unref(*to_source);
187                                 *to = safe_close(*to);
188                         } else if (errno != EAGAIN && errno != EINTR) {
189                                 log_error("Failed to splice: %m");
190                                 return -errno;
191                         }
192                 }
193         } while (shoveled);
194
195         return 0;
196 }
197
198 static int connection_enable_event_sources(Connection *c);
199
200 static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
201         Connection *c = userdata;
202         int r;
203
204         assert(s);
205         assert(fd >= 0);
206         assert(c);
207
208         r = connection_shovel(c,
209                               &c->server_fd, c->server_to_client_buffer, &c->client_fd,
210                               &c->server_to_client_buffer_full, &c->server_to_client_buffer_size,
211                               &c->server_event_source, &c->client_event_source);
212         if (r < 0)
213                 goto quit;
214
215         r = connection_shovel(c,
216                               &c->client_fd, c->client_to_server_buffer, &c->server_fd,
217                               &c->client_to_server_buffer_full, &c->client_to_server_buffer_size,
218                               &c->client_event_source, &c->server_event_source);
219         if (r < 0)
220                 goto quit;
221
222         /* EOF on both sides? */
223         if (c->server_fd == -1 && c->client_fd == -1)
224                 goto quit;
225
226         /* Server closed, and all data written to client? */
227         if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0)
228                 goto quit;
229
230         /* Client closed, and all data written to server? */
231         if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0)
232                 goto quit;
233
234         r = connection_enable_event_sources(c);
235         if (r < 0)
236                 goto quit;
237
238         return 1;
239
240 quit:
241         connection_free(c);
242         return 0; /* ignore errors, continue serving */
243 }
244
245 static int connection_enable_event_sources(Connection *c) {
246         uint32_t a = 0, b = 0;
247         int r;
248
249         assert(c);
250
251         if (c->server_to_client_buffer_full > 0)
252                 b |= EPOLLOUT;
253         if (c->server_to_client_buffer_full < c->server_to_client_buffer_size)
254                 a |= EPOLLIN;
255
256         if (c->client_to_server_buffer_full > 0)
257                 a |= EPOLLOUT;
258         if (c->client_to_server_buffer_full < c->client_to_server_buffer_size)
259                 b |= EPOLLIN;
260
261         if (c->server_event_source)
262                 r = sd_event_source_set_io_events(c->server_event_source, a);
263         else if (c->server_fd >= 0)
264                 r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c);
265         else
266                 r = 0;
267
268         if (r < 0) {
269                 log_error("Failed to set up server event source: %s", strerror(-r));
270                 return r;
271         }
272
273         if (c->client_event_source)
274                 r = sd_event_source_set_io_events(c->client_event_source, b);
275         else if (c->client_fd >= 0)
276                 r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c);
277         else
278                 r = 0;
279
280         if (r < 0) {
281                 log_error("Failed to set up client event source: %s", strerror(-r));
282                 return r;
283         }
284
285         return 0;
286 }
287
288 static int connection_complete(Connection *c) {
289         int r;
290
291         assert(c);
292
293         r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size);
294         if (r < 0)
295                 goto fail;
296
297         r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size);
298         if (r < 0)
299                 goto fail;
300
301         r = connection_enable_event_sources(c);
302         if (r < 0)
303                 goto fail;
304
305         return 0;
306
307 fail:
308         connection_free(c);
309         return 0; /* ignore errors, continue serving */
310 }
311
312 static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
313         Connection *c = userdata;
314         socklen_t solen;
315         int error, r;
316
317         assert(s);
318         assert(fd >= 0);
319         assert(c);
320
321         solen = sizeof(error);
322         r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen);
323         if (r < 0) {
324                 log_error("Failed to issue SO_ERROR: %m");
325                 goto fail;
326         }
327
328         if (error != 0) {
329                 log_error("Failed to connect to remote host: %s", strerror(error));
330                 goto fail;
331         }
332
333         c->client_event_source = sd_event_source_unref(c->client_event_source);
334
335         return connection_complete(c);
336
337 fail:
338         connection_free(c);
339         return 0; /* ignore errors, continue serving */
340 }
341
342 static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) {
343         int r;
344
345         assert(c);
346         assert(sa);
347         assert(salen);
348
349         c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0);
350         if (c->client_fd < 0) {
351                 log_error("Failed to get remote socket: %m");
352                 goto fail;
353         }
354
355         r = connect(c->client_fd, sa, salen);
356         if (r < 0) {
357                 if (errno == EINPROGRESS) {
358                         r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c);
359                         if (r < 0) {
360                                 log_error("Failed to add connection socket: %s", strerror(-r));
361                                 goto fail;
362                         }
363
364                         r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT);
365                         if (r < 0) {
366                                 log_error("Failed to enable oneshot event source: %s", strerror(-r));
367                                 goto fail;
368                         }
369                 } else {
370                         log_error("Failed to connect to remote host: %m");
371                         goto fail;
372                 }
373         } else {
374                 r = connection_complete(c);
375                 if (r < 0)
376                         goto fail;
377         }
378
379         return 0;
380
381 fail:
382         connection_free(c);
383         return 0; /* ignore errors, continue serving */
384 }
385
386 static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) {
387         Connection *c = userdata;
388
389         assert(q);
390         assert(c);
391
392         if (ret != 0) {
393                 log_error("Failed to resolve host: %s", gai_strerror(ret));
394                 goto fail;
395         }
396
397         c->resolve_query = sd_resolve_query_unref(c->resolve_query);
398
399         return connection_start(c, ai->ai_addr, ai->ai_addrlen);
400
401 fail:
402         connection_free(c);
403         return 0; /* ignore errors, continue serving */
404 }
405
406 static int resolve_remote(Connection *c) {
407
408         static const struct addrinfo hints = {
409                 .ai_family = AF_UNSPEC,
410                 .ai_socktype = SOCK_STREAM,
411                 .ai_flags = AI_ADDRCONFIG
412         };
413
414         union sockaddr_union sa = {};
415         const char *node, *service;
416         socklen_t salen;
417         int r;
418
419         if (path_is_absolute(arg_remote_host)) {
420                 sa.un.sun_family = AF_UNIX;
421                 strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)-1);
422                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
423
424                 salen = offsetof(union sockaddr_union, un.sun_path) + strlen(sa.un.sun_path);
425
426                 return connection_start(c, &sa.sa, salen);
427         }
428
429         if (arg_remote_host[0] == '@') {
430                 sa.un.sun_family = AF_UNIX;
431                 sa.un.sun_path[0] = 0;
432                 strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-2);
433                 sa.un.sun_path[sizeof(sa.un.sun_path)-1] = 0;
434
435                 salen = offsetof(union sockaddr_union, un.sun_path) + 1 + strlen(sa.un.sun_path + 1);
436
437                 return connection_start(c, &sa.sa, salen);
438         }
439
440         service = strrchr(arg_remote_host, ':');
441         if (service) {
442                 node = strndupa(arg_remote_host, service - arg_remote_host);
443                 service ++;
444         } else {
445                 node = arg_remote_host;
446                 service = "80";
447         }
448
449         log_debug("Looking up address info for %s:%s", node, service);
450         r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c);
451         if (r < 0) {
452                 log_error("Failed to resolve remote host: %s", strerror(-r));
453                 goto fail;
454         }
455
456         return 0;
457
458 fail:
459         connection_free(c);
460         return 0; /* ignore errors, continue serving */
461 }
462
463 static int add_connection_socket(Context *context, int fd) {
464         Connection *c;
465         int r;
466
467         assert(context);
468         assert(fd >= 0);
469
470         if (set_size(context->connections) > CONNECTIONS_MAX) {
471                 log_warning("Hit connection limit, refusing connection.");
472                 safe_close(fd);
473                 return 0;
474         }
475
476         r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func);
477         if (r < 0) {
478                 log_oom();
479                 return 0;
480         }
481
482         c = new0(Connection, 1);
483         if (!c) {
484                 log_oom();
485                 return 0;
486         }
487
488         c->context = context;
489         c->server_fd = fd;
490         c->client_fd = -1;
491         c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1;
492         c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1;
493
494         r = set_put(context->connections, c);
495         if (r < 0) {
496                 free(c);
497                 log_oom();
498                 return 0;
499         }
500
501         return resolve_remote(c);
502 }
503
504 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
505         _cleanup_free_ char *peer = NULL;
506         Context *context = userdata;
507         int nfd = -1, r;
508
509         assert(s);
510         assert(fd >= 0);
511         assert(revents & EPOLLIN);
512         assert(context);
513
514         nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC);
515         if (nfd < 0) {
516                 if (errno != -EAGAIN)
517                         log_warning("Failed to accept() socket: %m");
518         } else {
519                 getpeername_pretty(nfd, &peer);
520                 log_debug("New connection from %s", strna(peer));
521
522                 r = add_connection_socket(context, nfd);
523                 if (r < 0) {
524                         log_error("Failed to accept connection, ignoring: %s", strerror(-r));
525                         safe_close(fd);
526                 }
527         }
528
529         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
530         if (r < 0) {
531                 log_error("Error while re-enabling listener with ONESHOT: %s", strerror(-r));
532                 sd_event_exit(context->event, r);
533                 return r;
534         }
535
536         return 1;
537 }
538
539 static int add_listen_socket(Context *context, int fd) {
540         sd_event_source *source;
541         int r;
542
543         assert(context);
544         assert(fd >= 0);
545
546         r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func);
547         if (r < 0) {
548                 log_oom();
549                 return r;
550         }
551
552         r = sd_is_socket(fd, 0, SOCK_STREAM, 1);
553         if (r < 0) {
554                 log_error("Failed to determine socket type: %s", strerror(-r));
555                 return r;
556         }
557         if (r == 0) {
558                 log_error("Passed in socket is not a stream socket.");
559                 return -EINVAL;
560         }
561
562         r = fd_nonblock(fd, true);
563         if (r < 0) {
564                 log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r));
565                 return r;
566         }
567
568         r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context);
569         if (r < 0) {
570                 log_error("Failed to add event source: %s", strerror(-r));
571                 return r;
572         }
573
574         r = set_put(context->listen, source);
575         if (r < 0) {
576                 log_error("Failed to add source to set: %s", strerror(-r));
577                 sd_event_source_unref(source);
578                 return r;
579         }
580
581         /* Set the watcher to oneshot in case other processes are also
582          * watching to accept(). */
583         r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT);
584         if (r < 0) {
585                 log_error("Failed to enable oneshot mode: %s", strerror(-r));
586                 return r;
587         }
588
589         return 0;
590 }
591
592 static void help(void) {
593         printf("%1$s [HOST:PORT]\n"
594                "%1$s [SOCKET]\n\n"
595                "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n"
596                "  -h --help              Show this help\n"
597                "     --version           Show package version\n",
598                program_invocation_short_name);
599 }
600
601 static int parse_argv(int argc, char *argv[]) {
602
603         enum {
604                 ARG_VERSION = 0x100,
605                 ARG_IGNORE_ENV
606         };
607
608         static const struct option options[] = {
609                 { "help",       no_argument, NULL, 'h'           },
610                 { "version",    no_argument, NULL, ARG_VERSION   },
611                 {}
612         };
613
614         int c;
615
616         assert(argc >= 0);
617         assert(argv);
618
619         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0)
620
621                 switch (c) {
622
623                 case 'h':
624                         help();
625                         return 0;
626
627                 case ARG_VERSION:
628                         puts(PACKAGE_STRING);
629                         puts(SYSTEMD_FEATURES);
630                         return 0;
631
632                 case '?':
633                         return -EINVAL;
634
635                 default:
636                         assert_not_reached("Unhandled option");
637                 }
638
639         if (optind >= argc) {
640                 log_error("Not enough parameters.");
641                 return -EINVAL;
642         }
643
644         if (argc != optind+1) {
645                 log_error("Too many parameters.");
646                 return -EINVAL;
647         }
648
649         arg_remote_host = argv[optind];
650         return 1;
651 }
652
653 int main(int argc, char *argv[]) {
654         Context context = {};
655         int r, n, fd;
656
657         log_parse_environment();
658         log_open();
659
660         r = parse_argv(argc, argv);
661         if (r <= 0)
662                 goto finish;
663
664         r = sd_event_default(&context.event);
665         if (r < 0) {
666                 log_error("Failed to allocate event loop: %s", strerror(-r));
667                 goto finish;
668         }
669
670         r = sd_resolve_default(&context.resolve);
671         if (r < 0) {
672                 log_error("Failed to allocate resolver: %s", strerror(-r));
673                 goto finish;
674         }
675
676         r = sd_resolve_attach_event(context.resolve, context.event, 0);
677         if (r < 0) {
678                 log_error("Failed to attach resolver: %s", strerror(-r));
679                 goto finish;
680         }
681
682         sd_event_set_watchdog(context.event, true);
683
684         n = sd_listen_fds(1);
685         if (n < 0) {
686                 log_error("Failed to receive sockets from parent.");
687                 r = n;
688                 goto finish;
689         } else if (n == 0) {
690                 log_error("Didn't get any sockets passed in.");
691                 r = -EINVAL;
692                 goto finish;
693         }
694
695         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
696                 r = add_listen_socket(&context, fd);
697                 if (r < 0)
698                         goto finish;
699         }
700
701         r = sd_event_loop(context.event);
702         if (r < 0) {
703                 log_error("Failed to run event loop: %s", strerror(-r));
704                 goto finish;
705         }
706
707 finish:
708         context_free(&context);
709
710         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
711 }