chiark / gitweb /
socket-proxyd: Fix-up from previous change to avoid looping on errors.
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40
41 #define BUFFER_SIZE 16384
42 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
43
44 unsigned int total_clients = 0;
45
46 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
47
48 struct proxy {
49         int listen_fd;
50         bool ignore_env;
51         bool remote_is_inet;
52         const char *remote_host;
53         const char *remote_service;
54 };
55
56 struct connection {
57         int fd;
58         uint32_t events;
59         sd_event_source *w;
60         struct connection *c_destination;
61         size_t buffer_filled_len;
62         size_t buffer_sent_len;
63         char buffer[BUFFER_SIZE];
64 };
65
66 static void free_connection(struct connection *c) {
67         if (c != NULL) {
68                 log_debug("Freeing fd=%d (conn %p).", c->fd, c);
69                 sd_event_source_unref(c->w);
70                 if (c->fd > 0)
71                         close_nointr_nofail(c->fd);
72                 free(c);
73         }
74 }
75
76 static int add_event_to_connection(struct connection *c, uint32_t events) {
77         int r;
78
79         log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
80
81         c->events |= events;
82
83         r = sd_event_source_set_io_events(c->w, c->events);
84         if (r < 0) {
85                 log_error("Error %d setting revents: %s", r, strerror(-r));
86                 return r;
87         }
88
89         r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
90         if (r < 0) {
91                 log_error("Error %d enabling source: %s", r, strerror(-r));
92                 return r;
93         }
94
95         return 0;
96 }
97
98 static int remove_event_from_connection(struct connection *c, uint32_t events) {
99         int r;
100
101         log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
102
103         c->events &= ~events;
104
105         r = sd_event_source_set_io_events(c->w, c->events);
106         if (r < 0) {
107                 log_error("Error %d setting revents: %s", r, strerror(-r));
108                 return r;
109         }
110
111         if (c->events == 0) {
112             r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
113             if (r < 0) {
114                     log_error("Error %d disabling source: %s", r, strerror(-r));
115                     return r;
116             }
117         }
118
119         return 0;
120 }
121
122 static int send_buffer(struct connection *sender) {
123         struct connection *receiver = sender->c_destination;
124         ssize_t len;
125         int r = 0;
126
127         /* We cannot assume that even a partial send() indicates that
128          * the next send() will return EAGAIN or EWOULDBLOCK. Loop until
129          * it does. */
130         while (sender->buffer_filled_len > sender->buffer_sent_len) {
131                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
132                 log_debug("send(%d, ...)=%ld", receiver->fd, len);
133                 if (len < 0) {
134                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
135                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
136                                 return -errno;
137                         }
138                         else {
139                                 /* send() is in a would-block state. */
140                                 break;
141                         }
142                 }
143
144                 /* len < 0 can't occur here. len == 0 is possible but
145                  * undefined behavior for nonblocking send(). */
146                 assert(len > 0);
147                 sender->buffer_sent_len += len;
148         }
149
150         log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
151
152         /* Detect a would-block state or partial send. */
153         if (sender->buffer_filled_len > sender->buffer_sent_len) {
154
155                 /* If the buffer is full, disable events coming for recv. */
156                 if (sender->buffer_filled_len == BUFFER_SIZE) {
157                     r = remove_event_from_connection(sender, EPOLLIN);
158                     if (r < 0) {
159                             log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
160                             return r;
161                     }
162                 }
163
164                 /* Watch for when the recipient can be sent data again. */
165                 r = add_event_to_connection(receiver, EPOLLOUT);
166                 if (r < 0) {
167                         log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
168                         return r;
169                 }
170                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
171                 return r;
172         }
173
174         /* If we sent everything without any issues (would-block or
175          * partial send), the buffer is now empty. */
176         sender->buffer_filled_len = 0;
177         sender->buffer_sent_len = 0;
178
179         /* Enable the sender's receive watcher, in case the buffer was
180          * full and we disabled it. */
181         r = add_event_to_connection(sender, EPOLLIN);
182         if (r < 0) {
183                 log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
184                 return r;
185         }
186
187         /* Disable the other side's send watcher, as we have no data to send now. */
188         r = remove_event_from_connection(receiver, EPOLLOUT);
189         if (r < 0) {
190                 log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
191                 return r;
192         }
193
194         return 0;
195 }
196
197 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
198         struct connection *c = (struct connection *) userdata;
199         int r = 0;
200         ssize_t len;
201
202         assert(revents & (EPOLLIN | EPOLLOUT));
203         assert(fd == c->fd);
204         assert(s == c->w);
205
206         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
207
208         if (revents & EPOLLIN) {
209                 log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
210
211                 /* Receive until the buffer's full, there's no more data,
212                  * or the client/server disconnects. */
213                 while (c->buffer_filled_len < BUFFER_SIZE) {
214                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
215                         log_debug("recv(%d, ...)=%ld", fd, len);
216                         if (len < 0) {
217                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
218                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
219                                         return -errno;
220                                 }
221                                 else {
222                                         /* recv() is in a blocking state. */
223                                         break;
224                                 }
225                         }
226                         else if (len == 0) {
227                                 log_debug("Clean disconnection from fd=%d", fd);
228                                 total_clients--;
229                                 free_connection(c->c_destination);
230                                 free_connection(c);
231                                 return 0;
232                         }
233
234                         assert(len > 0);
235                         log_debug("Recording that the buffer got %ld more bytes full.", len);
236                         c->buffer_filled_len += len;
237                         log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len);
238                 }
239
240                 /* Try sending the data immediately. */
241                 return send_buffer(c);
242         }
243         else {
244                 return send_buffer(c->c_destination);
245         }
246
247         return r;
248 }
249
250 /* Once sending to the server is ready, set up the real watchers. */
251 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
252         struct sd_event *e = NULL;
253         struct connection *c_server_to_client = (struct connection *) userdata;
254         struct connection *c_client_to_server = c_server_to_client->c_destination;
255         int r;
256
257         assert(revents & EPOLLOUT);
258
259         e = sd_event_get(s);
260
261         /* Cancel the initial write watcher for the server. */
262         sd_event_source_unref(s);
263
264         log_debug("Connected to server. Initializing watchers for receiving data.");
265
266         /* A recv watcher for the server. */
267         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
268         if (r < 0) {
269                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
270                 goto fail;
271         }
272         c_server_to_client->events = EPOLLIN;
273
274         /* A recv watcher for the client. */
275         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
276         if (r < 0) {
277                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
278                 goto fail;
279         }
280         c_client_to_server->events = EPOLLIN;
281
282 goto finish;
283
284 fail:
285         free_connection(c_client_to_server);
286         free_connection(c_server_to_client);
287
288 finish:
289         return r;
290 }
291
292 static int get_server_connection_fd(const struct proxy *proxy) {
293         int server_fd;
294         int r = -EBADF;
295         int len;
296
297         if (proxy->remote_is_inet) {
298                 int s;
299                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
300                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
301                                          .ai_socktype = SOCK_STREAM,
302                                          .ai_flags = AI_PASSIVE};
303
304                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
305                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
306                 if (s != 0) {
307                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
308                         return r;
309                 }
310
311                 if (result == NULL) {
312                         log_error("getaddrinfo: no result");
313                         return r;
314                 }
315
316                 /* @TODO: Try connecting to all results instead of just the first. */
317                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
318                 if (server_fd < 0) {
319                         log_error("Error %d creating socket: %m", errno);
320                         return r;
321                 }
322
323                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
324                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
325                 if (r < 0 && errno != EINPROGRESS) {
326                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
327                         return r;
328                 }
329         }
330         else {
331                 struct sockaddr_un remote;
332
333                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
334                 if (server_fd < 0) {
335                         log_error("Error %d creating socket: %m", errno);
336                         return -EBADFD;
337                 }
338
339                 remote.sun_family = AF_UNIX;
340                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
341                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
342                 r = connect(server_fd, (struct sockaddr *) &remote, len);
343                 if (r < 0 && errno != EINPROGRESS) {
344                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
345                         return -EBADFD;
346                 }
347         }
348
349         log_debug("Server connection is fd=%d", server_fd);
350         return server_fd;
351 }
352
353 static int do_accept(sd_event *e, struct proxy *p, int fd) {
354         struct connection *c_server_to_client = NULL;
355         struct connection *c_client_to_server = NULL;
356         int r = 0;
357         union sockaddr_union sa;
358         socklen_t salen = sizeof(sa);
359         int client_fd, server_fd;
360
361         client_fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC);
362         if (client_fd < 0) {
363                 if (errno == EAGAIN || errno == EWOULDBLOCK)
364                     return -errno;
365                 log_error("Error %d accepting client connection: %m", errno);
366                 r = -errno;
367                 goto fail;
368         }
369
370         server_fd = get_server_connection_fd(p);
371         if (server_fd < 0) {
372                 log_error("Error initiating server connection.");
373                 r = server_fd;
374                 goto fail;
375         }
376
377         c_client_to_server = new0(struct connection, 1);
378         if (c_client_to_server == NULL) {
379                 log_oom();
380                 goto fail;
381         }
382
383         c_server_to_client = new0(struct connection, 1);
384         if (c_server_to_client == NULL) {
385                 log_oom();
386                 goto fail;
387         }
388
389         c_client_to_server->fd = client_fd;
390         c_server_to_client->fd = server_fd;
391
392         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
393                 char sa_str[INET6_ADDRSTRLEN];
394                 const char *success;
395
396                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
397                 if (success == NULL)
398                         log_warning("Error %d calling inet_ntop: %m", errno);
399                 else
400                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
401         }
402         else {
403                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
404         }
405
406         total_clients++;
407         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
408         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
409
410         /* Initialize watcher for send to server; this shows connectivity. */
411         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
412         if (r < 0) {
413                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
414                 goto fail;
415         }
416
417         /* Allow lookups of the opposite connection. */
418         c_server_to_client->c_destination = c_client_to_server;
419         c_client_to_server->c_destination = c_server_to_client;
420
421         goto finish;
422
423 fail:
424         log_warning("Accepting a client connection or connecting to the server failed.");
425         free_connection(c_client_to_server);
426         free_connection(c_server_to_client);
427
428 finish:
429         return r;
430 }
431
432 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
433         struct proxy *p = (struct proxy *) userdata;
434         sd_event *e = NULL;
435         int r = 0;
436
437         assert(revents & EPOLLIN);
438
439         e = sd_event_get(s);
440
441         for (;;) {
442                 r = do_accept(e, p, fd);
443                 if (r == -EAGAIN || r == -EWOULDBLOCK)
444                         break;
445                 if (r < 0) {
446                         log_error("Error %d while trying to accept: %s", r, strerror(-r));
447                         break;
448                 }
449         }
450
451         /* Re-enable the watcher. */
452         r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT);
453         if (r < 0) {
454                 log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r));
455                 return r;
456         }
457
458         /* Preserve the main loop even if a single accept() fails. */
459         return 1;
460 }
461
462 static int run_main_loop(struct proxy *proxy) {
463         _cleanup_event_source_unref_ sd_event_source *w_accept = NULL;
464         _cleanup_event_unref_ sd_event *e = NULL;
465         int r = EXIT_SUCCESS;
466
467         r = sd_event_new(&e);
468         if (r < 0) {
469                 log_error("Failed to allocate event loop: %s", strerror(-r));
470                 return r;
471         }
472
473         r = fd_nonblock(proxy->listen_fd, true);
474         if (r < 0) {
475                 log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r));
476                 return r;
477         }
478
479         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
480
481         r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
482         if (r < 0) {
483                 log_error("Error %d while adding event IO source: %s", r, strerror(-r));
484                 return r;
485         }
486
487         /* Set the watcher to oneshot in case other processes are also
488          * watching to accept(). */
489         r = sd_event_source_set_enabled(w_accept, SD_EVENT_ONESHOT);
490         if (r < 0) {
491                 log_error("Error %d while setting event IO source to ONESHOT: %s", r, strerror(-r));
492                 return r;
493         }
494
495         log_debug("Initialized main listener. Entering loop.");
496
497         return sd_event_loop(e);
498 }
499
500 static int help(void) {
501
502         printf("%s hostname-or-ip port-or-service\n"
503                "%s unix-domain-socket-path\n\n"
504                "Inherit a socket. Bidirectionally proxy.\n\n"
505                "  -h --help       Show this help\n"
506                "  --version       Print version and exit\n"
507                "  --ignore-env    Ignore expected systemd environment\n",
508                program_invocation_short_name,
509                program_invocation_short_name);
510
511         return 0;
512 }
513
514 static void version(void) {
515         puts(PACKAGE_STRING " socket-proxyd");
516 }
517
518 static int parse_argv(int argc, char *argv[], struct proxy *p) {
519
520         enum {
521                 ARG_VERSION = 0x100,
522                 ARG_IGNORE_ENV
523         };
524
525         static const struct option options[] = {
526                 { "help",       no_argument, NULL, 'h'           },
527                 { "version",    no_argument, NULL, ARG_VERSION   },
528                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
529                 { NULL,         0,           NULL, 0             }
530         };
531
532         int c;
533
534         assert(argc >= 0);
535         assert(argv);
536
537         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
538
539                 switch (c) {
540
541                 case 'h':
542                         help();
543                         return 0;
544
545                 case '?':
546                         return -EINVAL;
547
548                 case ARG_VERSION:
549                         version();
550                         return 0;
551
552                 case ARG_IGNORE_ENV:
553                         p->ignore_env = true;
554                         continue;
555
556                 default:
557                         log_error("Unknown option code %c", c);
558                         return -EINVAL;
559                 }
560         }
561
562         if (optind + 1 != argc && optind + 2 != argc) {
563                 log_error("Incorrect number of positional arguments.");
564                 help();
565                 return -EINVAL;
566         }
567
568         p->remote_host = argv[optind];
569         assert(p->remote_host);
570
571         p->remote_is_inet = p->remote_host[0] != '/';
572
573         if (optind == argc - 2) {
574                 if (!p->remote_is_inet) {
575                         log_error("A port or service is not allowed for Unix socket destinations.");
576                         help();
577                         return -EINVAL;
578                 }
579                 p->remote_service = argv[optind + 1];
580                 assert(p->remote_service);
581         } else if (p->remote_is_inet) {
582                 log_error("A port or service is required for IP destinations.");
583                 help();
584                 return -EINVAL;
585         }
586
587         return 1;
588 }
589
590 int main(int argc, char *argv[]) {
591         struct proxy p = {};
592         int r;
593
594         log_parse_environment();
595         log_open();
596
597         r = parse_argv(argc, argv, &p);
598         if (r <= 0)
599                 goto finish;
600
601         p.listen_fd = SD_LISTEN_FDS_START;
602
603         if (!p.ignore_env) {
604                 int n;
605                 n = sd_listen_fds(1);
606                 if (n == 0) {
607                         log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
608                         r = EXIT_FAILURE;
609                         goto finish;
610                 } else if (n < 0) {
611                         log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
612                         r = EXIT_FAILURE;
613                         goto finish;
614                 } else if (n > 1) {
615                         log_error("Can't listen on more than one socket.");
616                         r = EXIT_FAILURE;
617                         goto finish;
618                 }
619         }
620
621         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
622         if (r < 0) {
623                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
624                 goto finish;
625         }
626
627         log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
628
629         r = run_main_loop(&p);
630
631 finish:
632         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
633 }