chiark / gitweb /
build-sys: drop a number CFLAGS assignments in Makefile that are pointless
[elogind.git] / src / socket-proxy / socket-proxyd.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 David Strauss
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20  ***/
21
22 #include <arpa/inet.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <stdio.h>
26 #include <stdlib.h>
27 #include <string.h>
28 #include <netdb.h>
29 #include <sys/fcntl.h>
30 #include <sys/socket.h>
31 #include <sys/un.h>
32 #include <unistd.h>
33
34 #include "sd-daemon.h"
35 #include "sd-event.h"
36 #include "log.h"
37 #include "socket-util.h"
38 #include "util.h"
39 #include "event-util.h"
40
41 #define BUFFER_SIZE 16384
42 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop)
43
44 unsigned int total_clients = 0;
45
46 DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo);
47
48 struct proxy {
49         int listen_fd;
50         bool ignore_env;
51         bool remote_is_inet;
52         const char *remote_host;
53         const char *remote_service;
54 };
55
56 struct connection {
57         int fd;
58         uint32_t events;
59         sd_event_source *w;
60         struct connection *c_destination;
61         size_t buffer_filled_len;
62         size_t buffer_sent_len;
63         char buffer[BUFFER_SIZE];
64 };
65
66 static void free_connection(struct connection *c) {
67         log_debug("Freeing fd=%d (conn %p).", c->fd, c);
68         sd_event_source_unref(c->w);
69         close_nointr_nofail(c->fd);
70         free(c);
71 }
72
73 static int add_event_to_connection(struct connection *c, uint32_t events) {
74         int r;
75
76         log_debug("Have revents=%d. Adding revents=%d.", c->events, events);
77
78         c->events |= events;
79
80         r = sd_event_source_set_io_events(c->w, c->events);
81         if (r < 0) {
82                 log_error("Error %d setting revents: %s", r, strerror(-r));
83                 return r;
84         }
85
86         r = sd_event_source_set_enabled(c->w, SD_EVENT_ON);
87         if (r < 0) {
88                 log_error("Error %d enabling source: %s", r, strerror(-r));
89                 return r;
90         }
91
92         return 0;
93 }
94
95 static int remove_event_from_connection(struct connection *c, uint32_t events) {
96         int r;
97
98         log_debug("Have revents=%d. Removing revents=%d.", c->events, events);
99
100         c->events &= ~events;
101
102         r = sd_event_source_set_io_events(c->w, c->events);
103         if (r < 0) {
104                 log_error("Error %d setting revents: %s", r, strerror(-r));
105                 return r;
106         }
107
108         if (c->events == 0) {
109             r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF);
110             if (r < 0) {
111                     log_error("Error %d disabling source: %s", r, strerror(-r));
112                     return r;
113             }
114         }
115
116         return 0;
117 }
118
119 static int send_buffer(struct connection *sender) {
120         struct connection *receiver = sender->c_destination;
121         ssize_t len;
122         int r = 0;
123
124         /* We cannot assume that even a partial send() indicates that
125          * the next send() will block. Loop until it does. */
126         while (sender->buffer_filled_len > sender->buffer_sent_len) {
127                 len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0);
128                 log_debug("send(%d, ...)=%ld", receiver->fd, len);
129                 if (len < 0) {
130                         if (errno != EWOULDBLOCK && errno != EAGAIN) {
131                                 log_error("Error %d in send to fd=%d: %m", errno, receiver->fd);
132                                 return -errno;
133                         }
134                         else {
135                                 /* send() is in a blocking state. */
136                                 break;
137                         }
138                 }
139
140                 /* len < 0 can't occur here. len == 0 is possible but
141                  * undefined behavior for nonblocking send(). */
142                 assert(len > 0);
143                 sender->buffer_sent_len += len;
144         }
145
146         log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len);
147
148         /* Detect a would-block state or partial send. */
149         if (sender->buffer_filled_len > sender->buffer_sent_len) {
150
151                 /* If the buffer is full, disable events coming for recv. */
152                 if (sender->buffer_filled_len == BUFFER_SIZE) {
153                     r = remove_event_from_connection(sender, EPOLLIN);
154                     if (r < 0) {
155                             log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
156                             return r;
157                     }
158                 }
159
160                 /* Watch for when the recipient can be sent data again. */
161                 r = add_event_to_connection(receiver, EPOLLOUT);
162                 if (r < 0) {
163                         log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
164                         return r;
165                 }
166                 log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd);
167                 return r;
168         }
169
170         /* If we sent everything without blocking, the buffer is now empty. */
171         sender->buffer_filled_len = 0;
172         sender->buffer_sent_len = 0;
173
174         /* Enable the sender's receive watcher, in case the buffer was
175          * full and we disabled it. */
176         r = add_event_to_connection(sender, EPOLLIN);
177         if (r < 0) {
178                 log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r));
179                 return r;
180         }
181
182         /* Disable the other side's send watcher, as we have no data to send now. */
183         r = remove_event_from_connection(receiver, EPOLLOUT);
184         if (r < 0) {
185                 log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r));
186                 return r;
187         }
188
189         return 0;
190 }
191
192 static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
193         struct connection *c = (struct connection *) userdata;
194         int r = 0;
195         ssize_t len;
196
197         assert(revents & (EPOLLIN | EPOLLOUT));
198         assert(fd == c->fd);
199         assert(s == c->w);
200
201         log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c);
202
203         if (revents & EPOLLIN) {
204                 log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len);
205
206                 /* Receive until the buffer's full, there's no more data,
207                  * or the client/server disconnects. */
208                 while (c->buffer_filled_len < BUFFER_SIZE) {
209                         len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0);
210                         log_debug("recv(%d, ...)=%ld", fd, len);
211                         if (len < 0) {
212                                 if (errno != EWOULDBLOCK && errno != EAGAIN) {
213                                         log_error("Error %d in recv from fd=%d: %m", errno, fd);
214                                         return -errno;
215                                 }
216                                 else {
217                                         /* recv() is in a blocking state. */
218                                         break;
219                                 }
220                         }
221                         else if (len == 0) {
222                                 log_debug("Clean disconnection from fd=%d", fd);
223                                 total_clients--;
224                                 free_connection(c->c_destination);
225                                 free_connection(c);
226                                 return 0;
227                         }
228
229                         assert(len > 0);
230                         log_debug("Recording that the buffer got %ld more bytes full.", len);
231                         c->buffer_filled_len += len;
232                         log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len);
233                 }
234
235                 /* Try sending the data immediately. */
236                 return send_buffer(c);
237         }
238         else {
239                 return send_buffer(c->c_destination);
240         }
241
242         return r;
243 }
244
245 /* Once sending to the server is unblocked, set up the real watchers. */
246 static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
247         struct sd_event *e = NULL;
248         struct connection *c_server_to_client = (struct connection *) userdata;
249         struct connection *c_client_to_server = c_server_to_client->c_destination;
250         int r;
251
252         assert(revents & EPOLLOUT);
253
254         e = sd_event_get(s);
255
256         /* Cancel the initial write watcher for the server. */
257         sd_event_source_unref(s);
258
259         log_debug("Connected to server. Initializing watchers for receiving data.");
260
261         /* A recv watcher for the server. */
262         r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w);
263         if (r < 0) {
264                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
265                 goto fail;
266         }
267         c_server_to_client->events = EPOLLIN;
268
269         /* A recv watcher for the client. */
270         r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w);
271         if (r < 0) {
272                 log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r));
273                 goto fail;
274         }
275         c_client_to_server->events = EPOLLIN;
276
277 goto finish;
278
279 fail:
280         free_connection(c_client_to_server);
281         free_connection(c_server_to_client);
282
283 finish:
284         return r;
285 }
286
287 static int get_server_connection_fd(const struct proxy *proxy) {
288         int server_fd;
289         int r = -EBADF;
290         int len;
291
292         if (proxy->remote_is_inet) {
293                 int s;
294                 _cleanup_freeaddrinfo_ struct addrinfo *result = NULL;
295                 struct addrinfo hints = {.ai_family = AF_UNSPEC,
296                                          .ai_socktype = SOCK_STREAM,
297                                          .ai_flags = AI_PASSIVE};
298
299                 log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service);
300                 s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
301                 if (s != 0) {
302                         log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
303                         return r;
304                 }
305
306                 if (result == NULL) {
307                         log_error("getaddrinfo: no result");
308                         return r;
309                 }
310
311                 /* @TODO: Try connecting to all results instead of just the first. */
312                 server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol);
313                 if (server_fd < 0) {
314                         log_error("Error %d creating socket: %m", errno);
315                         return r;
316                 }
317
318                 r = connect(server_fd, result->ai_addr, result->ai_addrlen);
319                 /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */
320                 if (r < 0 && errno != EINPROGRESS) {
321                         log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service);
322                         return r;
323                 }
324         }
325         else {
326                 struct sockaddr_un remote;
327
328                 server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
329                 if (server_fd < 0) {
330                         log_error("Error %d creating socket: %m", errno);
331                         return -EBADFD;
332                 }
333
334                 remote.sun_family = AF_UNIX;
335                 strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path));
336                 len = strlen(remote.sun_path) + sizeof(remote.sun_family);
337                 r = connect(server_fd, (struct sockaddr *) &remote, len);
338                 if (r < 0 && errno != EINPROGRESS) {
339                         log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host);
340                         return -EBADFD;
341                 }
342         }
343
344         log_debug("Server connection is fd=%d", server_fd);
345         return server_fd;
346 }
347
348 static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
349         struct proxy *proxy = (struct proxy *) userdata;
350         struct connection *c_server_to_client;
351         struct connection *c_client_to_server;
352         int r = 0;
353         union sockaddr_union sa;
354         socklen_t salen = sizeof(sa);
355
356         assert(revents & EPOLLIN);
357
358         c_server_to_client = new0(struct connection, 1);
359         if (c_server_to_client == NULL) {
360                 log_oom();
361                 goto fail;
362         }
363
364         c_client_to_server = new0(struct connection, 1);
365         if (c_client_to_server == NULL) {
366                 log_oom();
367                 goto fail;
368         }
369
370         c_server_to_client->fd = get_server_connection_fd(proxy);
371         if (c_server_to_client->fd < 0) {
372                 log_error("Error initiating server connection.");
373                 goto fail;
374         }
375
376         c_client_to_server->fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC);
377         if (c_client_to_server->fd < 0) {
378                 log_error("Error accepting client connection.");
379                 goto fail;
380         }
381
382
383         if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) {
384                 char sa_str[INET6_ADDRSTRLEN];
385                 const char *success;
386
387                 success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN);
388                 if (success == NULL)
389                         log_warning("Error %d calling inet_ntop: %m", errno);
390                 else
391                         log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd);
392         }
393         else {
394                 log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd);
395         }
396
397         total_clients++;
398         log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients);
399         log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client);
400
401         /* Initialize watcher for send to server; this shows connectivity. */
402         r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w);
403         if (r < 0) {
404                 log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r));
405                 goto fail;
406         }
407
408         /* Allow lookups of the opposite connection. */
409         c_server_to_client->c_destination = c_client_to_server;
410         c_client_to_server->c_destination = c_server_to_client;
411
412         goto finish;
413
414 fail:
415         log_warning("Accepting a client connection or connecting to the server failed.");
416         free_connection(c_client_to_server);
417         free_connection(c_server_to_client);
418
419 finish:
420         /* Preserve the main loop even if a single proxy setup fails. */
421         return 1;
422 }
423
424 static int run_main_loop(struct proxy *proxy) {
425         _cleanup_event_source_unref_ sd_event_source *w_accept = NULL;
426         _cleanup_event_unref_ sd_event *e = NULL;
427         int r = EXIT_SUCCESS;
428
429         r = sd_event_new(&e);
430         if (r < 0) {
431                 log_error("Failed to allocate event loop: %s", strerror(-r));
432                 return r;
433         }
434
435         r = fd_nonblock(proxy->listen_fd, true);
436         if (r < 0) {
437                 log_error("Failed to make listen file descriptor non-blocking: %s", strerror(-r));
438                 return r;
439         }
440
441         log_debug("Initializing main listener fd=%d", proxy->listen_fd);
442
443         r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
444         if (r < 0) {
445                 log_error("Failed to add event IO source: %s", strerror(-r));
446                 return r;
447         }
448
449         log_debug("Initialized main listener. Entering loop.");
450
451         return sd_event_loop(e);
452 }
453
454 static int help(void) {
455
456         printf("%s hostname-or-ip port-or-service\n"
457                "%s unix-domain-socket-path\n\n"
458                "Inherit a socket. Bidirectionally proxy.\n\n"
459                "  -h --help       Show this help\n"
460                "  --version       Print version and exit\n"
461                "  --ignore-env    Ignore expected systemd environment\n",
462                program_invocation_short_name,
463                program_invocation_short_name);
464
465         return 0;
466 }
467
468 static void version(void) {
469         puts(PACKAGE_STRING " socket-proxyd");
470 }
471
472 static int parse_argv(int argc, char *argv[], struct proxy *p) {
473
474         enum {
475                 ARG_VERSION = 0x100,
476                 ARG_IGNORE_ENV
477         };
478
479         static const struct option options[] = {
480                 { "help",       no_argument, NULL, 'h'           },
481                 { "version",    no_argument, NULL, ARG_VERSION   },
482                 { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV},
483                 { NULL,         0,           NULL, 0             }
484         };
485
486         int c;
487
488         assert(argc >= 0);
489         assert(argv);
490
491         while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) {
492
493                 switch (c) {
494
495                 case 'h':
496                         help();
497                         return 0;
498
499                 case '?':
500                         return -EINVAL;
501
502                 case ARG_VERSION:
503                         version();
504                         return 0;
505
506                 case ARG_IGNORE_ENV:
507                         p->ignore_env = true;
508                         continue;
509
510                 default:
511                         log_error("Unknown option code %c", c);
512                         return -EINVAL;
513                 }
514         }
515
516         if (optind + 1 != argc && optind + 2 != argc) {
517                 log_error("Incorrect number of positional arguments.");
518                 help();
519                 return -EINVAL;
520         }
521
522         p->remote_host = argv[optind];
523         assert(p->remote_host);
524
525         p->remote_is_inet = p->remote_host[0] != '/';
526
527         if (optind == argc - 2) {
528                 if (!p->remote_is_inet) {
529                         log_error("A port or service is not allowed for Unix socket destinations.");
530                         help();
531                         return -EINVAL;
532                 }
533                 p->remote_service = argv[optind + 1];
534                 assert(p->remote_service);
535         } else if (p->remote_is_inet) {
536                 log_error("A port or service is required for IP destinations.");
537                 help();
538                 return -EINVAL;
539         }
540
541         return 1;
542 }
543
544 int main(int argc, char *argv[]) {
545         struct proxy p = {};
546         int r;
547
548         log_parse_environment();
549         log_open();
550
551         r = parse_argv(argc, argv, &p);
552         if (r <= 0)
553                 goto finish;
554
555         p.listen_fd = SD_LISTEN_FDS_START;
556
557         if (!p.ignore_env) {
558                 int n;
559                 n = sd_listen_fds(1);
560                 if (n == 0) {
561                         log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?");
562                         r = EXIT_FAILURE;
563                         goto finish;
564                 } else if (n < 0) {
565                         log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n));
566                         r = EXIT_FAILURE;
567                         goto finish;
568                 } else if (n > 1) {
569                         log_error("Can't listen on more than one socket.");
570                         r = EXIT_FAILURE;
571                         goto finish;
572                 }
573         }
574
575         /* @TODO: Check if this proxy can work with datagram sockets. */
576         r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1);
577         if (r < 0) {
578                 log_error("Error %d while checking inherited socket: %s", r, strerror(-r));
579                 goto finish;
580         }
581
582         log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd);
583
584         r = run_main_loop(&p);
585
586 finish:
587         return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
588 }