chiark / gitweb /
bus: make sure sd_bus_get_timeout() returns a 0 timeout of there are already read...
[elogind.git] / src / saproxy / saproxy.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "log.h"
35 #include "sd-daemon.h"
36 #include "sd-event.h"
37 #include "socket-util.h"
38 #include "util.h"
39
40 #define BUFFER_SIZE 16384
41 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
42
43 unsigned int total_clients = 0;
44
45 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
46
47 struct proxy {
48         int listen_fd;
49         bool ignore_env;
50         bool remote_is_inet;
51         const char *remote_host;
52         const char *remote_service;
53 };
54
55 struct connection {
56         int fd;
57         uint32_t events;
58         sd_event_source *w;
59         struct connection *c_destination;
60         size_t buffer_filled_len;
61         size_t buffer_sent_len;
62         char buffer[BUFFER_SIZE];
63 };
64
65 static void free_connection(struct connection *c) {
66         log_debug("Freeing fd=%d (conn %p).", c->fd, c);
67         sd_event_source_unref(c->w);
68         close(c->fd);
69         free(c);
70 }
71
72 static int add_event_to_connection(struct connection *c, uint32_t events) {
73         int r;
74
75         log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
76
77         c->events |= events;
78
79         r = sd_event_source_set_io_events(c->w, c->events);
80         if (r < 0) {
81                 log_error("Error %d setting revents: %s", r, strerror(-r));
82                 return r;
83         }
84
85         r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
86         if (r < 0) {
87                 log_error("Error %d enabling source: %s", r, strerror(-r));
88                 return r;
89         }
90
91         return 0;
92 }
93
94 static int remove_event_from_connection(struct connection *c, uint32_t events) {
95         int r;
96
97         log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
98
99         c->events &= ~events;
100
101         r = sd_event_source_set_io_events(c->w, c->events);
102         if (r < 0) {
103                 log_error("Error %d setting revents: %s", r, strerror(-r));
104                 return r;
105         }
106
107         if (c->events == 0) {
108             r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
109             if (r < 0) {
110                     log_error("Error %d disabling source: %s", r, strerror(-r));
111                     return r;
112             }
113         }
114
115         return 0;
116 }
117
118 static int send_buffer(struct connection *sender) {
119         struct connection *receiver = sender->c_destination;
120         ssize_t len;
121         int r = 0;
122
123         /* We cannot assume that even a partial send() indicates that
124          * the next send() will block. Loop until it does. */
125         while (sender->buffer_filled_len > sender->buffer_sent_len) {
126                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
127                 log_debug("send(%d, ...)=%ld", receiver->fd, len);
128                 if (len < 0) {
129                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
130                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
131                                 return -errno;
132                         }
133                         else {
134                                 /* send() is in a blocking state. */
135                                 break;
136                         }
137                 }
138
139                 /* len < 0 can't occur here. len == 0 is possible but
140                  * undefined behavior for nonblocking send(). */
141                 assert(len > 0);
142                 sender->buffer_sent_len += len;
143         }
144
145         log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
146
147         /* Detect a would-block state or partial send. */
148         if (sender->buffer_filled_len > sender->buffer_sent_len) {
149
150                 /* If the buffer is full, disable events coming for recv. */
151                 if (sender->buffer_filled_len == BUFFER_SIZE) {
152                     r = remove_event_from_connection(sender, EPOLLIN);
153                     if (r < 0) {
154                             log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
155                             return r;
156                     }
157                 }
158
159                 /* Watch for when the recipient can be sent data again. */
160                 r = add_event_to_connection(receiver, EPOLLOUT);
161                 if (r < 0) {
162                         log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
163                         return r;
164                 }
165                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
166                 return r;
167         }
168
169         /* If we sent everything without blocking, the buffer is now empty. */
170         sender->buffer_filled_len = 0;
171         sender->buffer_sent_len = 0;
172
173         /* Enable the sender's receive watcher, in case the buffer was
174          * full and we disabled it. */
175         r = add_event_to_connection(sender, EPOLLIN);
176         if (r < 0) {
177                 log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
178                 return r;
179         }
180
181         /* Disable the other side's send watcher, as we have no data to send now. */
182         r = remove_event_from_connection(receiver, EPOLLOUT);
183         if (r < 0) {
184                 log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
185                 return r;
186         }
187
188         return 0;
189 }
190
191 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
192         struct connection *c = (struct connection *) userdata;
193         int r = 0;
194         ssize_t len;
195
196         assert(revents & (EPOLLIN | EPOLLOUT));
197         assert(fd == c->fd);
198         assert(s == c->w);
199
200         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
201
202         if (revents & EPOLLIN) {
203                 log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
204
205                 /* Receive until the buffer's full, there's no more data,
206                  * or the client/server disconnects. */
207                 while (c->buffer_filled_len < BUFFER_SIZE) {
208                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
209                         log_debug("recv(%d, ...)=%ld", fd, len);
210                         if (len < 0) {
211                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
212                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
213                                         return -errno;
214                                 }
215                                 else {
216                                         /* recv() is in a blocking state. */
217                                         break;
218                                 }
219                         }
220                         else if (len == 0) {
221                                 log_debug("Clean disconnection from fd=%d", fd);
222                                 total_clients--;
223                                 free_connection(c->c_destination);
224                                 free_connection(c);
225                                 return 0;
226                         }
227
228                         assert(len > 0);
229                         log_debug("Recording that the buffer got %ld more bytes full.", len);
230                         c->buffer_filled_len += len;
231                         log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len);
232                 }
233
234                 /* Try sending the data immediately. */
235                 return send_buffer(c);
236         }
237         else {
238                 return send_buffer(c->c_destination);
239         }
240
241         return r;
242 }
243
244 /* Once sending to the server is unblocked, set up the real watchers. */
245 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
246         struct sd_event *e = NULL;
247         struct connection *c_server_to_client = (struct connection *) userdata;
248         struct connection *c_client_to_server = c_server_to_client->c_destination;
249         int r;
250
251         assert(revents & EPOLLOUT);
252
253         e = sd_event_get(s);
254
255         /* Cancel the initial write watcher for the server. */
256         sd_event_source_unref(s);
257
258         log_debug("Connected to server. Initializing watchers for receiving data.");
259
260         /* A recv watcher for the server. */
261         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
262         if (r < 0) {
263                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
264                 goto fail;
265         }
266         c_server_to_client->events = EPOLLIN;
267
268         /* A recv watcher for the client. */
269         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
270         if (r < 0) {
271                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
272                 goto fail;
273         }
274         c_client_to_server->events = EPOLLIN;
275
276 goto finish;
277
278 fail:
279         free_connection(c_client_to_server);
280         free_connection(c_server_to_client);
281
282 finish:
283         return r;
284 }
285
286 static int get_server_connection_fd(const struct proxy *proxy) {
287         int server_fd;
288         int r = -EBADF;
289         int len;
290
291         if (proxy->remote_is_inet) {
292                 int s;
293                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
294                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
295                                          .ai_socktype = SOCK_STREAM,
296                                          .ai_flags = AI_PASSIVE};
297
298                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
299                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
300                 if (s != 0) {
301                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
302                         return r;
303                 }
304
305                 if (result == NULL) {
306                         log_error("getaddrinfo: no result");
307                         return r;
308                 }
309
310                 /* @TODO: Try connecting to all results instead of just the first. */
311                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
312                 if (server_fd < 0) {
313                         log_error("Error %d creating socket: %m", errno);
314                         return r;
315                 }
316
317                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
318                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
319                 if (r < 0 && errno != EINPROGRESS) {
320                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
321                         return r;
322                 }
323         }
324         else {
325                 struct sockaddr_un remote;
326
327                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
328                 if (server_fd < 0) {
329                         log_error("Error %d creating socket: %m", errno);
330                         return -EBADFD;
331                 }
332
333                 remote.sun_family = AF_UNIX;
334                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
335                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
336                 r = connect(server_fd, (struct sockaddr *) &remote, len);
337                 if (r < 0 && errno != EINPROGRESS) {
338                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
339                         return -EBADFD;
340                 }
341         }
342
343         log_debug("Server connection is fd=%d", server_fd);
344         return server_fd;
345 }
346
347 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
348         struct proxy *proxy = (struct proxy *) userdata;
349         struct connection *c_server_to_client;
350         struct connection *c_client_to_server;
351         int r = 0;
352         union sockaddr_union sa;
353         socklen_t salen = sizeof(sa);
354
355         assert(revents & EPOLLIN);
356
357         c_server_to_client = malloc0(sizeof(struct connection));
358         if (c_server_to_client == NULL) {
359                 log_oom();
360                 goto fail;
361         }
362
363         c_client_to_server = malloc0(sizeof(struct connection));
364         if (c_client_to_server == NULL) {
365                 log_oom();
366                 goto fail;
367         }
368
369         c_server_to_client->fd = get_server_connection_fd(proxy);
370         if (c_server_to_client->fd < 0) {
371                 log_error("Error initiating server connection.");
372                 goto fail;
373         }
374
375         c_client_to_server->fd = accept(fd, (struct sockaddr *) &sa, &salen);
376         if (c_client_to_server->fd < 0) {
377                 log_error("Error accepting client connection.");
378                 goto fail;
379         }
380
381         /* Unlike on BSD, client sockets do not inherit nonblocking status
382          * from the listening socket. */
383         r = fd_nonblock(c_client_to_server->fd, true);
384         if (r < 0) {
385                 log_error("Error %d marking client connection as nonblocking: %s", r, strerror(-r));
386                 goto fail;
387         }
388
389         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
390                 char sa_str[INET6_ADDRSTRLEN];
391                 const char *success;
392
393                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
394                 if (success == NULL)
395                         log_warning("Error %d calling inet_ntop: %m", errno);
396                 else
397                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
398         }
399         else {
400                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
401         }
402
403         total_clients++;
404         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
405         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
406
407         /* Initialize watcher for send to server; this shows connectivity. */
408         r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
409         if (r < 0) {
410                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
411                 goto fail;
412         }
413
414         /* Allow lookups of the opposite connection. */
415         c_server_to_client->c_destination = c_client_to_server;
416         c_client_to_server->c_destination = c_server_to_client;
417
418         goto finish;
419
420 fail:
421         log_warning("Accepting a client connection or connecting to the server failed.");
422         free_connection(c_client_to_server);
423         free_connection(c_server_to_client);
424
425 finish:
426         /* Preserve the main loop even if a single proxy setup fails. */
427         return 0;
428 }
429
430 static int run_main_loop(struct proxy *proxy) {
431         int r = EXIT_SUCCESS;
432         struct sd_event *e = NULL;
433         sd_event_source *w_accept = NULL;
434
435         r = sd_event_new(&e);
436         if (r < 0)
437                 goto finish;
438
439         r = fd_nonblock(proxy->listen_fd, true);
440         if (r < 0)
441                 goto finish;
442
443         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
444
445         sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
446
447         log_debug("Initialized main listener. Entering loop.");
448
449         sd_event_loop(e);
450
451 finish:
452         sd_event_source_unref(w_accept);
453         sd_event_unref(e);
454
455         return r;
456 }
457
458 static int help(void) {
459
460         printf("%s hostname-or-ip port-or-service\n"
461                "%s unix-domain-socket-path\n\n"
462                "Inherit a socket. Bidirectionally proxy.\n\n"
463                "  -h --help       Show this help\n"
464                "  --version       Print version and exit\n"
465                "  --ignore-env    Ignore expected systemd environment\n",
466                program_invocation_short_name,
467                program_invocation_short_name);
468
469         return 0;
470 }
471
472 static void version(void) {
473         puts(PACKAGE_STRING " saproxy");
474 }
475
476 static int parse_argv(int argc, char *argv[], struct proxy *p) {
477
478         enum {
479                 ARG_VERSION = 0x100,
480                 ARG_IGNORE_ENV
481         };
482
483         static const struct option options[] = {
484                 { "help",       no_argument, NULL, 'h'           },
485                 { "version",    no_argument, NULL, ARG_VERSION   },
486                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
487                 { NULL,         0,           NULL, 0             }
488         };
489
490         int c;
491
492         assert(argc >= 0);
493         assert(argv);
494
495         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
496
497                 switch (c) {
498
499                 case 'h':
500                         help();
501                         return 0;
502
503                 case '?':
504                         return -EINVAL;
505
506                 case ARG_VERSION:
507                         version();
508                         return 0;
509
510                 case ARG_IGNORE_ENV:
511                         p->ignore_env = true;
512                         continue;
513
514                 default:
515                         log_error("Unknown option code %c", c);
516                         return -EINVAL;
517                 }
518         }
519
520         if (optind + 1 != argc && optind + 2 != argc) {
521                 log_error("Incorrect number of positional arguments.");
522                 help();
523                 return -EINVAL;
524         }
525
526         p->remote_host = argv[optind];
527         assert(p->remote_host);
528
529         p->remote_is_inet = p->remote_host[0] != '/';
530
531         if (optind == argc - 2) {
532                 if (!p->remote_is_inet) {
533                         log_error("A port or service is not allowed for Unix socket destinations.");
534                         help();
535                         return -EINVAL;
536                 }
537                 p->remote_service = argv[optind + 1];
538                 assert(p->remote_service);
539         } else if (p->remote_is_inet) {
540                 log_error("A port or service is required for IP destinations.");
541                 help();
542                 return -EINVAL;
543         }
544
545         return 1;
546 }
547
548 int main(int argc, char *argv[]) {
549         struct proxy p = {};
550         int r;
551
552         log_parse_environment();
553         log_open();
554
555         r = parse_argv(argc, argv, &p);
556         if (r <= 0)
557                 goto finish;
558
559         p.listen_fd = SD_LISTEN_FDS_START;
560
561         if (!p.ignore_env) {
562             int n;
563             n = sd_listen_fds(1);
564             if (n == 0) {
565                     log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
566                     r = EXIT_FAILURE;
567                     goto finish;
568             } else if (n < 0) {
569                     log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
570                     r = EXIT_FAILURE;
571                     goto finish;
572             } else if (n > 1) {
573                     log_error("Can't listen on more than one socket.");
574                     r = EXIT_FAILURE;
575                     goto finish;
576             }
577         }
578
579         /* @TODO: Check if this proxy can work with datagram sockets. */
580         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
581         if (r < 0) {
582                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
583                 goto finish;
584         }
585
586         log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
587
588         r = run_main_loop(&p);
589         if (r < 0) {
590                 log_error("Error %d from main loop.", r);
591                 goto finish;
592         }
593
594 finish:
595         log_close();
596         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
597 }