chiark / gitweb /
socket-proxyd: Remove datagram research TODO. This proxy will not work with them.
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40
41 #define BUFFER_SIZE 16384
42 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
43
44 unsigned int total_clients = 0;
45
46 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
47
48 struct proxy {
49         int listen_fd;
50         bool ignore_env;
51         bool remote_is_inet;
52         const char *remote_host;
53         const char *remote_service;
54 };
55
56 struct connection {
57         int fd;
58         uint32_t events;
59         sd_event_source *w;
60         struct connection *c_destination;
61         size_t buffer_filled_len;
62         size_t buffer_sent_len;
63         char buffer[BUFFER_SIZE];
64 };
65
66 static void free_connection(struct connection *c) {
67         log_debug("Freeing fd=%d (conn %p).", c->fd, c);
68         sd_event_source_unref(c->w);
69         close_nointr_nofail(c->fd);
70         free(c);
71 }
72
73 static int add_event_to_connection(struct connection *c, uint32_t events) {
74         int r;
75
76         log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
77
78         c->events |= events;
79
80         r = sd_event_source_set_io_events(c->w, c->events);
81         if (r < 0) {
82                 log_error("Error %d setting revents: %s", r, strerror(-r));
83                 return r;
84         }
85
86         r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
87         if (r < 0) {
88                 log_error("Error %d enabling source: %s", r, strerror(-r));
89                 return r;
90         }
91
92         return 0;
93 }
94
95 static int remove_event_from_connection(struct connection *c, uint32_t events) {
96         int r;
97
98         log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
99
100         c->events &= ~events;
101
102         r = sd_event_source_set_io_events(c->w, c->events);
103         if (r < 0) {
104                 log_error("Error %d setting revents: %s", r, strerror(-r));
105                 return r;
106         }
107
108         if (c->events == 0) {
109             r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
110             if (r < 0) {
111                     log_error("Error %d disabling source: %s", r, strerror(-r));
112                     return r;
113             }
114         }
115
116         return 0;
117 }
118
119 static int send_buffer(struct connection *sender) {
120         struct connection *receiver = sender->c_destination;
121         ssize_t len;
122         int r = 0;
123
124         /* We cannot assume that even a partial send() indicates that
125          * the next send() will return EAGAIN or EWOULDBLOCK. Loop until
126          * it does. */
127         while (sender->buffer_filled_len > sender->buffer_sent_len) {
128                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
129                 log_debug("send(%d, ...)=%ld", receiver->fd, len);
130                 if (len < 0) {
131                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
132                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
133                                 return -errno;
134                         }
135                         else {
136                                 /* send() is in a would-block state. */
137                                 break;
138                         }
139                 }
140
141                 /* len < 0 can't occur here. len == 0 is possible but
142                  * undefined behavior for nonblocking send(). */
143                 assert(len > 0);
144                 sender->buffer_sent_len += len;
145         }
146
147         log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
148
149         /* Detect a would-block state or partial send. */
150         if (sender->buffer_filled_len > sender->buffer_sent_len) {
151
152                 /* If the buffer is full, disable events coming for recv. */
153                 if (sender->buffer_filled_len == BUFFER_SIZE) {
154                     r = remove_event_from_connection(sender, EPOLLIN);
155                     if (r < 0) {
156                             log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
157                             return r;
158                     }
159                 }
160
161                 /* Watch for when the recipient can be sent data again. */
162                 r = add_event_to_connection(receiver, EPOLLOUT);
163                 if (r < 0) {
164                         log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
165                         return r;
166                 }
167                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
168                 return r;
169         }
170
171         /* If we sent everything without any issues (would-block or
172          * partial send), the buffer is now empty. */
173         sender->buffer_filled_len = 0;
174         sender->buffer_sent_len = 0;
175
176         /* Enable the sender's receive watcher, in case the buffer was
177          * full and we disabled it. */
178         r = add_event_to_connection(sender, EPOLLIN);
179         if (r < 0) {
180                 log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
181                 return r;
182         }
183
184         /* Disable the other side's send watcher, as we have no data to send now. */
185         r = remove_event_from_connection(receiver, EPOLLOUT);
186         if (r < 0) {
187                 log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
188                 return r;
189         }
190
191         return 0;
192 }
193
194 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
195         struct connection *c = (struct connection *) userdata;
196         int r = 0;
197         ssize_t len;
198
199         assert(revents & (EPOLLIN | EPOLLOUT));
200         assert(fd == c->fd);
201         assert(s == c->w);
202
203         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
204
205         if (revents & EPOLLIN) {
206                 log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
207
208                 /* Receive until the buffer's full, there's no more data,
209                  * or the client/server disconnects. */
210                 while (c->buffer_filled_len < BUFFER_SIZE) {
211                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
212                         log_debug("recv(%d, ...)=%ld", fd, len);
213                         if (len < 0) {
214                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
215                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
216                                         return -errno;
217                                 }
218                                 else {
219                                         /* recv() is in a blocking state. */
220                                         break;
221                                 }
222                         }
223                         else if (len == 0) {
224                                 log_debug("Clean disconnection from fd=%d", fd);
225                                 total_clients--;
226                                 free_connection(c->c_destination);
227                                 free_connection(c);
228                                 return 0;
229                         }
230
231                         assert(len > 0);
232                         log_debug("Recording that the buffer got %ld more bytes full.", len);
233                         c->buffer_filled_len += len;
234                         log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len);
235                 }
236
237                 /* Try sending the data immediately. */
238                 return send_buffer(c);
239         }
240         else {
241                 return send_buffer(c->c_destination);
242         }
243
244         return r;
245 }
246
247 /* Once sending to the server is ready, set up the real watchers. */
248 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
249         struct sd_event *e = NULL;
250         struct connection *c_server_to_client = (struct connection *) userdata;
251         struct connection *c_client_to_server = c_server_to_client->c_destination;
252         int r;
253
254         assert(revents & EPOLLOUT);
255
256         e = sd_event_get(s);
257
258         /* Cancel the initial write watcher for the server. */
259         sd_event_source_unref(s);
260
261         log_debug("Connected to server. Initializing watchers for receiving data.");
262
263         /* A recv watcher for the server. */
264         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
265         if (r < 0) {
266                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
267                 goto fail;
268         }
269         c_server_to_client->events = EPOLLIN;
270
271         /* A recv watcher for the client. */
272         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
273         if (r < 0) {
274                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
275                 goto fail;
276         }
277         c_client_to_server->events = EPOLLIN;
278
279 goto finish;
280
281 fail:
282         free_connection(c_client_to_server);
283         free_connection(c_server_to_client);
284
285 finish:
286         return r;
287 }
288
289 static int get_server_connection_fd(const struct proxy *proxy) {
290         int server_fd;
291         int r = -EBADF;
292         int len;
293
294         if (proxy->remote_is_inet) {
295                 int s;
296                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
297                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
298                                          .ai_socktype = SOCK_STREAM,
299                                          .ai_flags = AI_PASSIVE};
300
301                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
302                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
303                 if (s != 0) {
304                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
305                         return r;
306                 }
307
308                 if (result == NULL) {
309                         log_error("getaddrinfo: no result");
310                         return r;
311                 }
312
313                 /* @TODO: Try connecting to all results instead of just the first. */
314                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
315                 if (server_fd < 0) {
316                         log_error("Error %d creating socket: %m", errno);
317                         return r;
318                 }
319
320                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
321                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
322                 if (r < 0 && errno != EINPROGRESS) {
323                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
324                         return r;
325                 }
326         }
327         else {
328                 struct sockaddr_un remote;
329
330                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
331                 if (server_fd < 0) {
332                         log_error("Error %d creating socket: %m", errno);
333                         return -EBADFD;
334                 }
335
336                 remote.sun_family = AF_UNIX;
337                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
338                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
339                 r = connect(server_fd, (struct sockaddr *) &remote, len);
340                 if (r < 0 && errno != EINPROGRESS) {
341                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
342                         return -EBADFD;
343                 }
344         }
345
346         log_debug("Server connection is fd=%d", server_fd);
347         return server_fd;
348 }
349
350 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
351         struct proxy *proxy = (struct proxy *) userdata;
352         struct connection *c_server_to_client;
353         struct connection *c_client_to_server = NULL;
354         int r = 0;
355         union sockaddr_union sa;
356         socklen_t salen = sizeof(sa);
357
358         assert(revents & EPOLLIN);
359
360         c_server_to_client = new0(struct connection, 1);
361         if (c_server_to_client == NULL) {
362                 log_oom();
363                 goto fail;
364         }
365
366         c_client_to_server = new0(struct connection, 1);
367         if (c_client_to_server == NULL) {
368                 log_oom();
369                 goto fail;
370         }
371
372         c_server_to_client->fd = get_server_connection_fd(proxy);
373         if (c_server_to_client->fd < 0) {
374                 log_error("Error initiating server connection.");
375                 goto fail;
376         }
377
378         c_client_to_server->fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC);
379         if (c_client_to_server->fd < 0) {
380                 log_error("Error accepting client connection.");
381                 goto fail;
382         }
383
384
385         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
386                 char sa_str[INET6_ADDRSTRLEN];
387                 const char *success;
388
389                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
390                 if (success == NULL)
391                         log_warning("Error %d calling inet_ntop: %m", errno);
392                 else
393                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
394         }
395         else {
396                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
397         }
398
399         total_clients++;
400         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
401         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
402
403         /* Initialize watcher for send to server; this shows connectivity. */
404         r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
405         if (r < 0) {
406                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
407                 goto fail;
408         }
409
410         /* Allow lookups of the opposite connection. */
411         c_server_to_client->c_destination = c_client_to_server;
412         c_client_to_server->c_destination = c_server_to_client;
413
414         goto finish;
415
416 fail:
417         log_warning("Accepting a client connection or connecting to the server failed.");
418         free_connection(c_client_to_server);
419         free_connection(c_server_to_client);
420
421 finish:
422         /* Preserve the main loop even if a single proxy setup fails. */
423         return 1;
424 }
425
426 static int run_main_loop(struct proxy *proxy) {
427         _cleanup_event_source_unref_ sd_event_source *w_accept = NULL;
428         _cleanup_event_unref_ sd_event *e = NULL;
429         int r = EXIT_SUCCESS;
430
431         r = sd_event_new(&e);
432         if (r < 0) {
433                 log_error("Failed to allocate event loop: %s", strerror(-r));
434                 return r;
435         }
436
437         r = fd_nonblock(proxy->listen_fd, true);
438         if (r < 0) {
439                 log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r));
440                 return r;
441         }
442
443         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
444
445         r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
446         if (r < 0) {
447                 log_error("Failed to add event IO source: %s", strerror(-r));
448                 return r;
449         }
450
451         log_debug("Initialized main listener. Entering loop.");
452
453         return sd_event_loop(e);
454 }
455
456 static int help(void) {
457
458         printf("%s hostname-or-ip port-or-service\n"
459                "%s unix-domain-socket-path\n\n"
460                "Inherit a socket. Bidirectionally proxy.\n\n"
461                "  -h --help       Show this help\n"
462                "  --version       Print version and exit\n"
463                "  --ignore-env    Ignore expected systemd environment\n",
464                program_invocation_short_name,
465                program_invocation_short_name);
466
467         return 0;
468 }
469
470 static void version(void) {
471         puts(PACKAGE_STRING " socket-proxyd");
472 }
473
474 static int parse_argv(int argc, char *argv[], struct proxy *p) {
475
476         enum {
477                 ARG_VERSION = 0x100,
478                 ARG_IGNORE_ENV
479         };
480
481         static const struct option options[] = {
482                 { "help",       no_argument, NULL, 'h'           },
483                 { "version",    no_argument, NULL, ARG_VERSION   },
484                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
485                 { NULL,         0,           NULL, 0             }
486         };
487
488         int c;
489
490         assert(argc >= 0);
491         assert(argv);
492
493         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
494
495                 switch (c) {
496
497                 case 'h':
498                         help();
499                         return 0;
500
501                 case '?':
502                         return -EINVAL;
503
504                 case ARG_VERSION:
505                         version();
506                         return 0;
507
508                 case ARG_IGNORE_ENV:
509                         p->ignore_env = true;
510                         continue;
511
512                 default:
513                         log_error("Unknown option code %c", c);
514                         return -EINVAL;
515                 }
516         }
517
518         if (optind + 1 != argc && optind + 2 != argc) {
519                 log_error("Incorrect number of positional arguments.");
520                 help();
521                 return -EINVAL;
522         }
523
524         p->remote_host = argv[optind];
525         assert(p->remote_host);
526
527         p->remote_is_inet = p->remote_host[0] != '/';
528
529         if (optind == argc - 2) {
530                 if (!p->remote_is_inet) {
531                         log_error("A port or service is not allowed for Unix socket destinations.");
532                         help();
533                         return -EINVAL;
534                 }
535                 p->remote_service = argv[optind + 1];
536                 assert(p->remote_service);
537         } else if (p->remote_is_inet) {
538                 log_error("A port or service is required for IP destinations.");
539                 help();
540                 return -EINVAL;
541         }
542
543         return 1;
544 }
545
546 int main(int argc, char *argv[]) {
547         struct proxy p = {};
548         int r;
549
550         log_parse_environment();
551         log_open();
552
553         r = parse_argv(argc, argv, &p);
554         if (r <= 0)
555                 goto finish;
556
557         p.listen_fd = SD_LISTEN_FDS_START;
558
559         if (!p.ignore_env) {
560                 int n;
561                 n = sd_listen_fds(1);
562                 if (n == 0) {
563                         log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
564                         r = EXIT_FAILURE;
565                         goto finish;
566                 } else if (n < 0) {
567                         log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
568                         r = EXIT_FAILURE;
569                         goto finish;
570                 } else if (n > 1) {
571                         log_error("Can't listen on more than one socket.");
572                         r = EXIT_FAILURE;
573                         goto finish;
574                 }
575         }
576
577         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
578         if (r < 0) {
579                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
580                 goto finish;
581         }
582
583         log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
584
585         r = run_main_loop(&p);
586
587 finish:
588         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
589 }