chiark / gitweb /
clients: unify how we invoke getopt_long()
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40 #include "build.h"
41
42 #define BUFFER_SIZE 16384
43 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
44
45 unsigned int total_clients = 0;
46
47 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
48
49 struct proxy {
50         int listen_fd;
51         bool ignore_env;
52         bool remote_is_inet;
53         const char *remote_host;
54         const char *remote_service;
55 };
56
57 struct connection {
58         int fd;
59         uint32_t events;
60         sd_event_source *w;
61         struct connection *c_destination;
62         size_t buffer_filled_len;
63         size_t buffer_sent_len;
64         char buffer[BUFFER_SIZE];
65 };
66
67 static void free_connection(struct connection *c) {
68         if (c != NULL) {
69                 log_debug("Freeing fd=%d (conn %p).", c->fd, c);
70                 sd_event_source_unref(c->w);
71                 if (c->fd > 0)
72                         close_nointr_nofail(c->fd);
73                 free(c);
74         }
75 }
76
77 static int add_event_to_connection(struct connection *c, uint32_t events) {
78         int r;
79
80         log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
81
82         c->events |= events;
83
84         r = sd_event_source_set_io_events(c->w, c->events);
85         if (r < 0) {
86                 log_error("Error %d setting revents: %s", r, strerror(-r));
87                 return r;
88         }
89
90         r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
91         if (r < 0) {
92                 log_error("Error %d enabling source: %s", r, strerror(-r));
93                 return r;
94         }
95
96         return 0;
97 }
98
99 static int remove_event_from_connection(struct connection *c, uint32_t events) {
100         int r;
101
102         log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
103
104         c->events &= ~events;
105
106         r = sd_event_source_set_io_events(c->w, c->events);
107         if (r < 0) {
108                 log_error("Error %d setting revents: %s", r, strerror(-r));
109                 return r;
110         }
111
112         if (c->events == 0) {
113             r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
114             if (r < 0) {
115                     log_error("Error %d disabling source: %s", r, strerror(-r));
116                     return r;
117             }
118         }
119
120         return 0;
121 }
122
123 static int send_buffer(struct connection *sender) {
124         struct connection *receiver = sender->c_destination;
125         ssize_t len;
126         int r = 0;
127
128         /* We cannot assume that even a partial send() indicates that
129          * the next send() will return EAGAIN or EWOULDBLOCK. Loop until
130          * it does. */
131         while (sender->buffer_filled_len > sender->buffer_sent_len) {
132                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
133                 log_debug("send(%d, ...)=%zd", receiver->fd, len);
134                 if (len < 0) {
135                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
136                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
137                                 return -errno;
138                         }
139                         else {
140                                 /* send() is in a would-block state. */
141                                 break;
142                         }
143                 }
144
145                 /* len < 0 can't occur here. len == 0 is possible but
146                  * undefined behavior for nonblocking send(). */
147                 assert(len > 0);
148                 sender->buffer_sent_len += len;
149         }
150
151         log_debug("send(%d, ...) completed with %zu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
152
153         /* Detect a would-block state or partial send. */
154         if (sender->buffer_filled_len > sender->buffer_sent_len) {
155
156                 /* If the buffer is full, disable events coming for recv. */
157                 if (sender->buffer_filled_len == BUFFER_SIZE) {
158                     r = remove_event_from_connection(sender, EPOLLIN);
159                     if (r < 0) {
160                             log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
161                             return r;
162                     }
163                 }
164
165                 /* Watch for when the recipient can be sent data again. */
166                 r = add_event_to_connection(receiver, EPOLLOUT);
167                 if (r < 0) {
168                         log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
169                         return r;
170                 }
171                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
172                 return r;
173         }
174
175         /* If we sent everything without any issues (would-block or
176          * partial send), the buffer is now empty. */
177         sender->buffer_filled_len = 0;
178         sender->buffer_sent_len = 0;
179
180         /* Enable the sender's receive watcher, in case the buffer was
181          * full and we disabled it. */
182         r = add_event_to_connection(sender, EPOLLIN);
183         if (r < 0) {
184                 log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
185                 return r;
186         }
187
188         /* Disable the other side's send watcher, as we have no data to send now. */
189         r = remove_event_from_connection(receiver, EPOLLOUT);
190         if (r < 0) {
191                 log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
192                 return r;
193         }
194
195         return 0;
196 }
197
198 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
199         struct connection *c = (struct connection *) userdata;
200         int r = 0;
201         ssize_t len;
202
203         assert(revents & (EPOLLIN | EPOLLOUT));
204         assert(fd == c->fd);
205         assert(s == c->w);
206
207         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
208
209         if (revents & EPOLLIN) {
210                 log_debug("About to recv up to %zu bytes from fd=%d (%zu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
211
212                 /* Receive until the buffer's full, there's no more data,
213                  * or the client/server disconnects. */
214                 while (c->buffer_filled_len < BUFFER_SIZE) {
215                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
216                         log_debug("recv(%d, ...)=%zd", fd, len);
217                         if (len < 0) {
218                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
219                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
220                                         return -errno;
221                                 }
222                                 else {
223                                         /* recv() is in a blocking state. */
224                                         break;
225                                 }
226                         }
227                         else if (len == 0) {
228                                 log_debug("Clean disconnection from fd=%d", fd);
229                                 total_clients--;
230                                 free_connection(c->c_destination);
231                                 free_connection(c);
232                                 return 0;
233                         }
234
235                         assert(len > 0);
236                         log_debug("Recording that the buffer got %zd more bytes full.", len);
237                         c->buffer_filled_len += len;
238                         log_debug("Buffer now has %zu bytes full.", c->buffer_filled_len);
239                 }
240
241                 /* Try sending the data immediately. */
242                 return send_buffer(c);
243         }
244         else {
245                 return send_buffer(c->c_destination);
246         }
247
248         return r;
249 }
250
251 /* Once sending to the server is ready, set up the real watchers. */
252 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
253         struct sd_event *e = NULL;
254         struct connection *c_server_to_client = (struct connection *) userdata;
255         struct connection *c_client_to_server = c_server_to_client->c_destination;
256         int r;
257
258         assert(revents & EPOLLOUT);
259
260         e = sd_event_get(s);
261
262         /* Cancel the initial write watcher for the server. */
263         sd_event_source_unref(s);
264
265         log_debug("Connected to server. Initializing watchers for receiving data.");
266
267         /* A recv watcher for the server. */
268         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
269         if (r < 0) {
270                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
271                 goto fail;
272         }
273         c_server_to_client->events = EPOLLIN;
274
275         /* A recv watcher for the client. */
276         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
277         if (r < 0) {
278                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
279                 goto fail;
280         }
281         c_client_to_server->events = EPOLLIN;
282
283 goto finish;
284
285 fail:
286         free_connection(c_client_to_server);
287         free_connection(c_server_to_client);
288
289 finish:
290         return r;
291 }
292
293 static int get_server_connection_fd(const struct proxy *proxy) {
294         int server_fd;
295         int r = -EBADF;
296         int len;
297
298         if (proxy->remote_is_inet) {
299                 int s;
300                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
301                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
302                                          .ai_socktype = SOCK_STREAM,
303                                          .ai_flags = AI_PASSIVE};
304
305                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
306                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
307                 if (s != 0) {
308                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
309                         return r;
310                 }
311
312                 if (result == NULL) {
313                         log_error("getaddrinfo: no result");
314                         return r;
315                 }
316
317                 /* @TODO: Try connecting to all results instead of just the first. */
318                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
319                 if (server_fd < 0) {
320                         log_error("Error %d creating socket: %m", errno);
321                         return r;
322                 }
323
324                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
325                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
326                 if (r < 0 && errno != EINPROGRESS) {
327                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
328                         return r;
329                 }
330         }
331         else {
332                 struct sockaddr_un remote;
333
334                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
335                 if (server_fd < 0) {
336                         log_error("Error %d creating socket: %m", errno);
337                         return -EBADFD;
338                 }
339
340                 remote.sun_family = AF_UNIX;
341                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
342                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
343                 r = connect(server_fd, (struct sockaddr *) &remote, len);
344                 if (r < 0 && errno != EINPROGRESS) {
345                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
346                         return -EBADFD;
347                 }
348         }
349
350         log_debug("Server connection is fd=%d", server_fd);
351         return server_fd;
352 }
353
354 static int do_accept(sd_event *e, struct proxy *p, int fd) {
355         struct connection *c_server_to_client = NULL;
356         struct connection *c_client_to_server = NULL;
357         int r = 0;
358         union sockaddr_union sa;
359         socklen_t salen = sizeof(sa);
360         int client_fd, server_fd;
361
362         client_fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC);
363         if (client_fd < 0) {
364                 if (errno == EAGAIN || errno == EWOULDBLOCK)
365                     return -errno;
366                 log_error("Error %d accepting client connection: %m", errno);
367                 r = -errno;
368                 goto fail;
369         }
370
371         server_fd = get_server_connection_fd(p);
372         if (server_fd < 0) {
373                 log_error("Error initiating server connection.");
374                 r = server_fd;
375                 goto fail;
376         }
377
378         c_client_to_server = new0(struct connection, 1);
379         if (c_client_to_server == NULL) {
380                 log_oom();
381                 goto fail;
382         }
383
384         c_server_to_client = new0(struct connection, 1);
385         if (c_server_to_client == NULL) {
386                 log_oom();
387                 goto fail;
388         }
389
390         c_client_to_server->fd = client_fd;
391         c_server_to_client->fd = server_fd;
392
393         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
394                 char sa_str[INET6_ADDRSTRLEN];
395                 const char *success;
396
397                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
398                 if (success == NULL)
399                         log_warning("Error %d calling inet_ntop: %m", errno);
400                 else
401                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
402         }
403         else {
404                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
405         }
406
407         total_clients++;
408         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
409         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
410
411         /* Initialize watcher for send to server; this shows connectivity. */
412         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
413         if (r < 0) {
414                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
415                 goto fail;
416         }
417
418         /* Allow lookups of the opposite connection. */
419         c_server_to_client->c_destination = c_client_to_server;
420         c_client_to_server->c_destination = c_server_to_client;
421
422         goto finish;
423
424 fail:
425         log_warning("Accepting a client connection or connecting to the server failed.");
426         free_connection(c_client_to_server);
427         free_connection(c_server_to_client);
428
429 finish:
430         return r;
431 }
432
433 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
434         struct proxy *p = (struct proxy *) userdata;
435         sd_event *e = NULL;
436         int r = 0;
437
438         assert(revents & EPOLLIN);
439
440         e = sd_event_get(s);
441
442         for (;;) {
443                 r = do_accept(e, p, fd);
444                 if (r == -EAGAIN || r == -EWOULDBLOCK)
445                         break;
446                 if (r < 0) {
447                         log_error("Error %d while trying to accept: %s", r, strerror(-r));
448                         break;
449                 }
450         }
451
452         /* Re-enable the watcher. */
453         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
454         if (r < 0) {
455                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
456                 return r;
457         }
458
459         /* Preserve the main loop even if a single accept() fails. */
460         return 1;
461 }
462
463 static int run_main_loop(struct proxy *proxy) {
464         _cleanup_event_source_unref_ sd_event_source *w_accept = NULL;
465         _cleanup_event_unref_ sd_event *e = NULL;
466         int r = EXIT_SUCCESS;
467
468         r = sd_event_new(&e);
469         if (r < 0) {
470                 log_error("Failed to allocate event loop: %s", strerror(-r));
471                 return r;
472         }
473
474         r = fd_nonblock(proxy->listen_fd, true);
475         if (r < 0) {
476                 log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r));
477                 return r;
478         }
479
480         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
481
482         r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
483         if (r < 0) {
484                 log_error("Error %d while adding event IO source: %s", r, strerror(-r));
485                 return r;
486         }
487
488         /* Set the watcher to oneshot in case other processes are also
489          * watching to accept(). */
490         r = sd_event_source_set_enabled(w_accept, SD_EVENT_ONESHOT);
491         if (r < 0) {
492                 log_error("Error %d while setting event IO source to ONESHOT: %s", r, strerror(-r));
493                 return r;
494         }
495
496         log_debug("Initialized main listener. Entering loop.");
497
498         return sd_event_loop(e);
499 }
500
501 static int help(void) {
502
503         printf("%s hostname-or-ip port-or-service\n"
504                "%s unix-domain-socket-path\n\n"
505                "Inherit a socket. Bidirectionally proxy.\n\n"
506                "  -h --help       Show this help\n"
507                "  --version       Print version and exit\n"
508                "  --ignore-env    Ignore expected systemd environment\n",
509                program_invocation_short_name,
510                program_invocation_short_name);
511
512         return 0;
513 }
514
515 static int parse_argv(int argc, char *argv[], struct proxy *p) {
516
517         enum {
518                 ARG_VERSION = 0x100,
519                 ARG_IGNORE_ENV
520         };
521
522         static const struct option options[] = {
523                 { "help",       no_argument, NULL, 'h'           },
524                 { "version",    no_argument, NULL, ARG_VERSION   },
525                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
526                 {}
527         };
528
529         int c;
530
531         assert(argc >= 0);
532         assert(argv);
533
534         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
535
536                 switch (c) {
537
538                 case 'h':
539                         return help();
540
541                 case ARG_VERSION:
542                         puts(PACKAGE_STRING);
543                         puts(SYSTEMD_FEATURES);
544                         return 0;
545
546                 case ARG_IGNORE_ENV:
547                         p->ignore_env = true;
548                         continue;
549
550                 case '?':
551                         return -EINVAL;
552
553                 default:
554                         assert_not_reached("Unhandled option");
555                 }
556         }
557
558         if (optind + 1 != argc && optind + 2 != argc) {
559                 log_error("Incorrect number of positional arguments.");
560                 help();
561                 return -EINVAL;
562         }
563
564         p->remote_host = argv[optind];
565         assert(p->remote_host);
566
567         p->remote_is_inet = p->remote_host[0] != '/';
568
569         if (optind == argc - 2) {
570                 if (!p->remote_is_inet) {
571                         log_error("A port or service is not allowed for Unix socket destinations.");
572                         help();
573                         return -EINVAL;
574                 }
575                 p->remote_service = argv[optind + 1];
576                 assert(p->remote_service);
577         } else if (p->remote_is_inet) {
578                 log_error("A port or service is required for IP destinations.");
579                 help();
580                 return -EINVAL;
581         }
582
583         return 1;
584 }
585
586 int main(int argc, char *argv[]) {
587         struct proxy p = {};
588         int r;
589
590         log_parse_environment();
591         log_open();
592
593         r = parse_argv(argc, argv, &p);
594         if (r <= 0)
595                 goto finish;
596
597         p.listen_fd = SD_LISTEN_FDS_START;
598
599         if (!p.ignore_env) {
600                 int n;
601                 n = sd_listen_fds(1);
602                 if (n == 0) {
603                         log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
604                         r = EXIT_FAILURE;
605                         goto finish;
606                 } else if (n < 0) {
607                         log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
608                         r = EXIT_FAILURE;
609                         goto finish;
610                 } else if (n > 1) {
611                         log_error("Can't listen on more than one socket.");
612                         r = EXIT_FAILURE;
613                         goto finish;
614                 }
615         }
616
617         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
618         if (r < 0) {
619                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
620                 goto finish;
621         }
622
623         log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
624
625         r = run_main_loop(&p);
626
627 finish:
628         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
629 }