chiark / gitweb /
Update TODOs with follow-up sabridge work.
[elogind.git] / src / sabridge / sabridge.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "log.h"
35 #include "sd-daemon.h"
36 #include "sd-event.h"
37 #include "socket-util.h"
38 #include "util.h"
39
40 #define BUFFER_SIZE 4096
41 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
42
43 unsigned int total_clients = 0;
44
45 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
46
47 struct proxy {
48         int listen_fd;
49         bool ignore_env;
50         bool remote_is_inet;
51         const char *remote_host;
52         const char *remote_service;
53 };
54
55 struct connection {
56         int fd;
57         sd_event_source *w_recv;
58         sd_event_source *w_send;
59         struct connection *c_destination;
60         size_t buffer_filled_len;
61         size_t buffer_sent_len;
62         char buffer[BUFFER_SIZE];
63 };
64
65 static void free_connection(struct connection *c) {
66         log_debug("Freeing fd=%d (conn %p).", c->fd, c);
67         sd_event_source_unref(c->w_recv);
68         sd_event_source_unref(c->w_send);
69         close(c->fd);
70         free(c);
71 }
72
73 static int send_buffer(struct connection *sender) {
74         struct connection *receiver = sender->c_destination;
75         ssize_t len;
76         int r = 0;
77
78         /* We cannot assume that even a partial send() indicates that
79          * the next send() will block. Loop until it does. */
80         while (sender->buffer_filled_len > sender->buffer_sent_len) {
81                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
82                 log_debug("send(%d, ...)=%ld", receiver->fd, len);
83                 if (len < 0) {
84                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
85                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
86                                 return -errno;
87                         }
88                         else {
89                                 /* send() is in a blocking state. */
90                                 break;
91                         }
92                 }
93
94                 /* len < 0 can't occur here. len == 0 is possible but
95                  * undefined behavior for nonblocking send(). */
96                 assert(len > 0);
97                 sender->buffer_sent_len += len;
98         }
99
100         log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
101
102         /* Detect a would-block state or partial send. */
103         if (sender->buffer_filled_len > sender->buffer_sent_len) {
104
105                 /* If the buffer is full, disable events coming for recv. */
106                 if (sender->buffer_filled_len == BUFFER_SIZE) {
107                     r = sd_event_source_set_enabled(sender->w_recv, SD_EVENT_OFF);
108                     if (r < 0) {
109                             log_error("Error %d disabling recv for fd=%d: %s", r, sender->fd, strerror(-r));
110                             return r;
111                     }
112                 }
113
114                 /* Watch for when the recipient can be sent data again. */
115                 r = sd_event_source_set_enabled(receiver->w_send, SD_EVENT_ON);
116                 if (r < 0) {
117                         log_error("Error %d enabling send for fd=%d: %s", r, receiver->fd, strerror(-r));
118                         return r;
119                 }
120                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
121                 return r;
122         }
123
124         /* If we sent everything without blocking, the buffer is now empty. */
125         sender->buffer_filled_len = 0;
126         sender->buffer_sent_len = 0;
127
128         /* Unmute the sender, in case the buffer was full. */
129         r = sd_event_source_set_enabled(sender->w_recv, SD_EVENT_ON);
130         if (r < 0) {
131                 log_error("Error %d enabling recv for fd=%d: %s", r, sender->fd, strerror(-r));
132                 return r;
133         }
134
135         /* Mute the recipient, as we have no data to send now. */
136         r = sd_event_source_set_enabled(receiver->w_send, SD_EVENT_OFF);
137         if (r < 0) {
138                 log_error("Error %d disabling send for fd=%d: %s", r, receiver->fd, strerror(-r));
139                 return r;
140         }
141
142         return 0;
143 }
144
145 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
146         struct connection *c = (struct connection *) userdata;
147         int r = 0;
148         ssize_t len;
149
150         assert(revents & (EPOLLIN | EPOLLOUT));
151         assert(fd == c->fd);
152
153         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
154
155         if (revents & EPOLLIN) {
156                 log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
157
158                 assert(s == c->w_recv);
159
160                 /* Receive until the buffer's full, there's no more data,
161                  * or the client/server disconnects. */
162                 while (c->buffer_filled_len < BUFFER_SIZE) {
163                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
164                         log_debug("recv(%d, ...)=%ld", fd, len);
165                         if (len < 0) {
166                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
167                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
168                                         return -errno;
169                                 }
170                                 else {
171                                         /* recv() is in a blocking state. */
172                                         break;
173                                 }
174                         }
175                         else if (len == 0) {
176                                 log_debug("Clean disconnection from fd=%d", fd);
177                                 total_clients--;
178                                 free_connection(c->c_destination);
179                                 free_connection(c);
180                                 return 0;
181                         }
182
183                         assert(len > 0);
184                         log_debug("Recording that the buffer got %ld more bytes full.", len);
185                         c->buffer_filled_len += len;
186                         log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len);
187                 }
188
189                 /* Try sending the data immediately. */
190                 return send_buffer(c);
191         }
192         else {
193                 assert(s == c->w_send);
194                 return send_buffer(c->c_destination);
195         }
196
197         return r;
198 }
199
200 /* Once sending to the server is unblocked, set up the real watchers. */
201 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
202         struct sd_event *e = NULL;
203         struct connection *c_server_to_client = (struct connection *) userdata;
204         struct connection *c_client_to_server = c_server_to_client->c_destination;
205         int r;
206
207         assert(revents & EPOLLOUT);
208
209         e = sd_event_get(s);
210
211         /* Cancel the initial write watcher for the server. */
212         sd_event_source_unref(s);
213
214         log_debug("Connected to server. Initializing watchers for receiving data.");
215
216         /* A disabled send watcher for the server. */
217         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, transfer_data_cb, c_server_to_client, &c_server_to_client->w_send);
218         if (r < 0) {
219                 log_error("Error %d creating send watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
220                 goto fail;
221         }
222         r = sd_event_source_set_enabled(c_server_to_client->w_send, SD_EVENT_OFF);
223         if (r < 0) {
224                 log_error("Error %d muting send for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
225                 goto finish;
226         }
227
228         /* A recv watcher for the server. */
229         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w_recv);
230         if (r < 0) {
231                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
232                 goto fail;
233         }
234
235         /* A disabled send watcher for the client. */
236         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLOUT, transfer_data_cb, c_client_to_server, &c_client_to_server->w_send);
237         if (r < 0) {
238                 log_error("Error %d creating send watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
239                 goto fail;
240         }
241         r = sd_event_source_set_enabled(c_client_to_server->w_send, SD_EVENT_OFF);
242         if (r < 0) {
243                 log_error("Error %d muting send for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
244                 goto finish;
245         }
246
247         /* A recv watcher for the client. */
248         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w_recv);
249         if (r < 0) {
250                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
251                 goto fail;
252         }
253
254 goto finish;
255
256 fail:
257         free_connection(c_client_to_server);
258         free_connection(c_server_to_client);
259
260 finish:
261         return r;
262 }
263
264 static int get_server_connection_fd(const struct proxy *proxy) {
265         int server_fd;
266         int r = -EBADF;
267         int len;
268
269         if (proxy->remote_is_inet) {
270                 int s;
271                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
272                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
273                                          .ai_socktype = SOCK_STREAM,
274                                          .ai_flags = AI_PASSIVE};
275
276                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
277                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
278                 if (s != 0) {
279                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
280                         return r;
281                 }
282
283                 if (result == NULL) {
284                         log_error("getaddrinfo: no result");
285                         return r;
286                 }
287
288                 /* @TODO: Try connecting to all results instead of just the first. */
289                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
290                 if (server_fd < 0) {
291                         log_error("Error %d creating socket: %m", errno);
292                         return r;
293                 }
294
295                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
296                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
297                 if (r < 0 && errno != EINPROGRESS) {
298                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
299                         return r;
300                 }
301         }
302         else {
303                 struct sockaddr_un remote;
304
305                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
306                 if (server_fd < 0) {
307                         log_error("Error %d creating socket: %m", errno);
308                         return -EBADFD;
309                 }
310
311                 remote.sun_family = AF_UNIX;
312                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
313                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
314                 r = connect(server_fd, (struct sockaddr *) &remote, len);
315                 if (r < 0 && errno != EINPROGRESS) {
316                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
317                         return -EBADFD;
318                 }
319         }
320
321         log_debug("Server connection is fd=%d", server_fd);
322         return server_fd;
323 }
324
325 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
326         struct proxy *proxy = (struct proxy *) userdata;
327         struct connection *c_server_to_client;
328         struct connection *c_client_to_server;
329         int r = 0;
330         union sockaddr_union sa;
331         socklen_t salen = sizeof(sa);
332
333         assert(revents & EPOLLIN);
334
335         c_server_to_client = malloc0(sizeof(struct connection));
336         if (c_server_to_client == NULL) {
337                 log_oom();
338                 goto fail;
339         }
340
341         c_client_to_server = malloc0(sizeof(struct connection));
342         if (c_client_to_server == NULL) {
343                 log_oom();
344                 goto fail;
345         }
346
347         c_server_to_client->fd = get_server_connection_fd(proxy);
348         if (c_server_to_client->fd < 0) {
349                 log_error("Error initiating server connection.");
350                 goto fail;
351         }
352
353         c_client_to_server->fd = accept(fd, (struct sockaddr *) &sa, &salen);
354         if (c_client_to_server->fd < 0) {
355                 log_error("Error accepting client connection.");
356                 goto fail;
357         }
358
359         /* Unlike on BSD, client sockets do not inherit nonblocking status
360          * from the listening socket. */
361         r = fd_nonblock(c_client_to_server->fd, true);
362         if (r < 0) {
363                 log_error("Error %d marking client connection as nonblocking: %s", r, strerror(-r));
364                 goto fail;
365         }
366
367         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
368                 char sa_str[INET6_ADDRSTRLEN];
369                 const char *success;
370
371                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
372                 if (success == NULL)
373                         log_warning("Error %d calling inet_ntop: %m", errno);
374                 else
375                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
376         }
377         else {
378                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
379         }
380
381         total_clients++;
382         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
383         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
384
385         /* Initialize watcher for send to server; this shows connectivity. */
386         r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w_send);
387         if (r < 0) {
388                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
389                 goto fail;
390         }
391
392         /* Allow lookups of the opposite connection. */
393         c_server_to_client->c_destination = c_client_to_server;
394         c_client_to_server->c_destination = c_server_to_client;
395
396         goto finish;
397
398 fail:
399         log_warning("Accepting a client connection or connecting to the server failed.");
400         free_connection(c_client_to_server);
401         free_connection(c_server_to_client);
402
403 finish:
404         /* Preserve the main loop even if a single proxy setup fails. */
405         return 0;
406 }
407
408 static int run_main_loop(struct proxy *proxy) {
409         int r = EXIT_SUCCESS;
410         struct sd_event *e = NULL;
411         sd_event_source *w_accept = NULL;
412
413         r = sd_event_new(&e);
414         if (r < 0)
415                 goto finish;
416
417         r = fd_nonblock(proxy->listen_fd, true);
418         if (r < 0)
419                 goto finish;
420
421         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
422
423         sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
424
425         log_debug("Initialized main listener. Entering loop.");
426
427         sd_event_loop(e);
428
429 finish:
430         sd_event_source_unref(w_accept);
431         sd_event_unref(e);
432
433         return r;
434 }
435
436 static int help(void) {
437
438         printf("%s hostname-or-ip port-or-service\n"
439                "%s unix-domain-socket-path\n\n"
440                "Inherit a socket. Bidirectionally proxy.\n\n"
441                "  -h --help       Show this help\n"
442                "  --version       Print version and exit\n"
443                "  --ignore-env    Ignore expected systemd environment\n",
444                program_invocation_short_name,
445                program_invocation_short_name);
446
447         return 0;
448 }
449
450 static void version(void) {
451         puts(PACKAGE_STRING " sabridge");
452 }
453
454 static int parse_argv(int argc, char *argv[], struct proxy *p) {
455
456         enum {
457                 ARG_VERSION = 0x100,
458                 ARG_IGNORE_ENV
459         };
460
461         static const struct option options[] = {
462                 { "help",       no_argument, NULL, 'h'           },
463                 { "version",    no_argument, NULL, ARG_VERSION   },
464                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
465                 { NULL,         0,           NULL, 0             }
466         };
467
468         int c;
469
470         assert(argc >= 0);
471         assert(argv);
472
473         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
474
475                 switch (c) {
476
477                 case 'h':
478                         help();
479                         return 0;
480
481                 case '?':
482                         return -EINVAL;
483
484                 case ARG_VERSION:
485                         version();
486                         return 0;
487
488                 case ARG_IGNORE_ENV:
489                         p->ignore_env = true;
490                         continue;
491
492                 default:
493                         log_error("Unknown option code %c", c);
494                         return -EINVAL;
495                 }
496         }
497
498         if (optind + 1 != argc && optind + 2 != argc) {
499                 log_error("Incorrect number of positional arguments.");
500                 help();
501                 return -EINVAL;
502         }
503
504         p->remote_host = argv[optind];
505         assert(p->remote_host);
506
507         p->remote_is_inet = p->remote_host[0] != '/';
508
509         if (optind == argc - 2) {
510                 if (!p->remote_is_inet) {
511                         log_error("A port or service is not allowed for Unix socket destinations.");
512                         help();
513                         return -EINVAL;
514                 }
515                 p->remote_service = argv[optind + 1];
516                 assert(p->remote_service);
517         } else if (p->remote_is_inet) {
518                 log_error("A port or service is required for IP destinations.");
519                 help();
520                 return -EINVAL;
521         }
522
523         return 1;
524 }
525
526 int main(int argc, char *argv[]) {
527         struct proxy p = {};
528         int r;
529
530         log_parse_environment();
531         log_open();
532
533         r = parse_argv(argc, argv, &p);
534         if (r <= 0)
535                 goto finish;
536
537         p.listen_fd = SD_LISTEN_FDS_START;
538
539         if (!p.ignore_env) {
540             int n;
541             n = sd_listen_fds(1);
542             if (n == 0) {
543                     log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
544                     r = EXIT_FAILURE;
545                     goto finish;
546             } else if (n < 0) {
547                     log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
548                     r = EXIT_FAILURE;
549                     goto finish;
550             } else if (n > 1) {
551                     log_error("Can't listen on more than one socket.");
552                     r = EXIT_FAILURE;
553                     goto finish;
554             }
555         }
556
557         /* @TODO: Check if this proxy can work with datagram sockets. */
558         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
559         if (r < 0) {
560                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
561                 goto finish;
562         }
563
564         log_info("Starting the socket activation bridge with listener fd=%d.", p.listen_fd);
565
566         r = run_main_loop(&p);
567         if (r < 0) {
568                 log_error("Error %d from main loop.", r);
569                 goto finish;
570         }
571
572 finish:
573         log_close();
574         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
575 }