chiark / gitweb /
b3ef428cee16c88ec5f8af7df835db32da5a42d7
[elogind.git] / src / activate / activate.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 Zbigniew JÄ™drzejewski-Szmek
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <unistd.h>
23 #include <fcntl.h>
24 #include <sys/epoll.h>
25 #include <sys/prctl.h>
26 #include <sys/socket.h>
27 #include <sys/wait.h>
28 #include <getopt.h>
29
30 #include <systemd/sd-daemon.h>
31
32 #include "socket-util.h"
33 #include "build.h"
34 #include "log.h"
35 #include "strv.h"
36 #include "macro.h"
37
38 static char** arg_listen = NULL;
39 static bool arg_accept = false;
40 static char** arg_args = NULL;
41 static char** arg_environ = NULL;
42
43 static int add_epoll(int epoll_fd, int fd) {
44         int r;
45         struct epoll_event ev = {EPOLLIN};
46         ev.data.fd = fd;
47
48         assert(epoll_fd >= 0);
49         assert(fd >= 0);
50
51         r = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev);
52         if (r < 0)
53                 log_error("Failed to add event on epoll fd:%d for fd:%d: %s",
54                           epoll_fd, fd, strerror(-r));
55         return r;
56 }
57
58 static int print_socket(const char* desc, int fd) {
59         int r;
60         SocketAddress addr = {
61                 .size = sizeof(union sockaddr_union),
62                 .type = SOCK_STREAM,
63         };
64         int family;
65
66         r = getsockname(fd, &addr.sockaddr.sa, &addr.size);
67         if (r < 0) {
68                 log_warning("Failed to query socket on fd:%d: %m", fd);
69                 return 0;
70         }
71
72         family = socket_address_family(&addr);
73         switch(family) {
74         case AF_INET:
75         case AF_INET6: {
76                 char* _cleanup_free_ a = NULL;
77                 r = socket_address_print(&addr, &a);
78                 if (r < 0)
79                         log_warning("socket_address_print(): %s", strerror(-r));
80                 else
81                         log_info("%s %s address %s",
82                                  desc,
83                                  family == AF_INET ? "IP" : "IPv6",
84                                  a);
85                 break;
86         }
87         default:
88                 log_warning("Connection with unknown family %d", family);
89         }
90
91         return 0;
92 }
93
94 static int make_socket_fd(const char* address, int flags) {
95         _cleanup_free_ char *p = NULL;
96         SocketAddress a;
97         int fd, r;
98
99         r = socket_address_parse(&a, address);
100         if (r < 0) {
101                 log_error("Failed to parse socket: %s", strerror(-r));
102                 return r;
103         }
104
105         fd = socket_address_listen(&a, flags, SOMAXCONN, SOCKET_ADDRESS_DEFAULT, NULL, false, false, 0755, 0644, NULL);
106         if (fd < 0) {
107                 log_error("Failed to listen: %s", strerror(-r));
108                 return fd;
109         }
110
111         r = socket_address_print(&a, &p);
112         if (r < 0) {
113                 log_error("socket_address_print(): %s", strerror(-r));
114                 close_nointr_nofail(fd);
115                 return r;
116         }
117
118         log_info("Listening on %s", p);
119
120         return fd;
121 }
122
123 static int open_sockets(int *epoll_fd, bool accept) {
124         int n, fd, r;
125         int count = 0;
126         char **address;
127
128         n = sd_listen_fds(true);
129         if (n < 0) {
130                 log_error("Failed to read listening file descriptors from environment: %s",
131                           strerror(-n));
132                 return n;
133         }
134         log_info("Received %d descriptors", n);
135
136         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
137                 log_debug("Received descriptor fd:%d", fd);
138                 print_socket("Listening on", fd);
139
140                 r = fd_cloexec(fd, arg_accept);
141                 if (r < 0)
142                         return r;
143
144                 count ++;
145         }
146
147         /** Note: we leak some fd's on error here. I doesn't matter
148          *  much, since the program will exit immediately anyway, but
149          *  would be a pain to fix.
150          */
151
152         STRV_FOREACH(address, arg_listen) {
153                 log_info("Opening address %s", *address);
154
155                 fd = make_socket_fd(*address, SOCK_STREAM | (arg_accept*SOCK_CLOEXEC));
156                 if (fd < 0) {
157                         log_error("Failed to open '%s': %s", *address, strerror(-fd));
158                         return fd;
159                 }
160
161                 assert(fd == SD_LISTEN_FDS_START + count);
162                 count ++;
163         }
164
165         *epoll_fd = epoll_create1(EPOLL_CLOEXEC);
166         if (*epoll_fd < 0) {
167                 log_error("Failed to create epoll object: %m");
168                 return -errno;
169         }
170
171
172         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + count; fd++) {
173                 r = add_epoll(*epoll_fd, fd);
174                 if (r < 0)
175                         return r;
176         }
177
178         return count;
179 }
180
181 static int launch(char* name, char **argv, char **env, int fds) {
182         unsigned n_env = 0, length;
183         _cleanup_strv_free_ char **envp = NULL;
184         char **s;
185         static const char* tocopy[] = {"TERM=", "PATH=", "USER=", "HOME="};
186         _cleanup_free_ char *tmp = NULL;
187         unsigned i;
188
189         length = strv_length(arg_environ);
190         /* PATH, TERM, HOME, USER, LISTEN_FDS, LISTEN_PID, NULL */
191         envp = new0(char *, length + 7);
192
193         STRV_FOREACH(s, arg_environ) {
194                 if (strchr(*s, '='))
195                         envp[n_env++] = *s;
196                 else {
197                         _cleanup_free_ char *p = strappend(*s, "=");
198                         if (!p)
199                                 return log_oom();
200                         envp[n_env] = strv_find_prefix(env, p);
201                         if (envp[n_env])
202                                 n_env ++;
203                 }
204         }
205
206         for (i = 0; i < ELEMENTSOF(tocopy); i++) {
207                 envp[n_env] = strv_find_prefix(env, tocopy[i]);
208                 if (envp[n_env])
209                         n_env ++;
210         }
211
212         if ((asprintf((char**)(envp + n_env++), "LISTEN_FDS=%d", fds) < 0) ||
213             (asprintf((char**)(envp + n_env++), "LISTEN_PID=%d", getpid()) < 0))
214                 return log_oom();
215
216         tmp = strv_join(argv, " ");
217         if (!tmp)
218                 return log_oom();
219
220         log_info("Execing %s (%s)", name, tmp);
221         execvpe(name, argv, envp);
222         log_error("Failed to execp %s (%s): %m", name, tmp);
223         return -errno;
224 }
225
226 static int launch1(const char* child, char** argv, char **env, int fd) {
227         pid_t parent_pid, child_pid;
228         int r;
229
230         _cleanup_free_ char *tmp = NULL;
231         tmp = strv_join(argv, " ");
232         if (!tmp)
233                 return log_oom();
234
235         parent_pid = getpid();
236
237         child_pid = fork();
238         if (child_pid < 0) {
239                 log_error("Failed to fork: %m");
240                 return -errno;
241         }
242
243         /* In the child */
244         if (child_pid == 0) {
245                 r = dup2(fd, STDIN_FILENO);
246                 if (r < 0) {
247                         log_error("Failed to dup connection to stdin: %m");
248                         _exit(EXIT_FAILURE);
249                 }
250
251                 r = dup2(fd, STDOUT_FILENO);
252                 if (r < 0) {
253                         log_error("Failed to dup connection to stdout: %m");
254                         _exit(EXIT_FAILURE);
255                 }
256
257                 r = close(fd);
258                 if (r < 0) {
259                         log_error("Failed to close dupped connection: %m");
260                         _exit(EXIT_FAILURE);
261                 }
262
263                 /* Make sure the child goes away when the parent dies */
264                 if (prctl(PR_SET_PDEATHSIG, SIGTERM) < 0)
265                         _exit(EXIT_FAILURE);
266
267                 /* Check whether our parent died before we were able
268                  * to set the death signal */
269                 if (getppid() != parent_pid)
270                         _exit(EXIT_SUCCESS);
271
272                 execvp(child, argv);
273                 log_error("Failed to exec child %s: %m", child);
274                 _exit(EXIT_FAILURE);
275         }
276
277         log_info("Spawned %s (%s) as PID %d", child, tmp, child_pid);
278
279         return 0;
280 }
281
282 static int do_accept(const char* name, char **argv, char **envp, int fd) {
283         SocketAddress addr = {
284                 .size = sizeof(union sockaddr_union),
285                 .type = SOCK_STREAM,
286         };
287         int fd2, r;
288
289         fd2 = accept(fd, &addr.sockaddr.sa, &addr.size);
290         if (fd2 < 0) {
291                 log_error("Failed to accept connection on fd:%d: %m", fd);
292                 return fd2;
293         }
294
295         print_socket("Connection from", fd2);
296
297         r = launch1(name, argv, envp, fd2);
298         return r;
299 }
300
301 /* SIGCHLD handler. */
302 static void sigchld_hdl(int sig, siginfo_t *t, void *data) {
303         log_info("Child %d died with code %d", t->si_pid, t->si_status);
304         /* Wait for a dead child. */
305         waitpid(t->si_pid, NULL, 0);
306 }
307
308 static int install_chld_handler(void) {
309         int r;
310         struct sigaction act;
311         zero(act);
312         act.sa_flags = SA_SIGINFO;
313         act.sa_sigaction = sigchld_hdl;
314
315         r = sigaction(SIGCHLD, &act, 0);
316         if (r < 0)
317                 log_error("Failed to install SIGCHLD handler: %m");
318         return r;
319 }
320
321 static int help(void) {
322         printf("%s [OPTIONS...]\n\n"
323                "Listen on sockets and launch child on connection.\n\n"
324                "Options:\n"
325                "  -l --listen=ADDR     Listen for raw connections at ADDR\n"
326                "  -a --accept          Spawn separate child for each connection\n"
327                "  -h --help            Show this help and exit\n"
328                "  --version            Print version string and exit\n"
329                "\n"
330                "Note: file descriptors from sd_listen_fds() will be passed through.\n"
331                , program_invocation_short_name
332                );
333
334         return 0;
335 }
336
337 static int parse_argv(int argc, char *argv[]) {
338         enum {
339                 ARG_VERSION = 0x100,
340         };
341
342         static const struct option options[] = {
343                 { "help",         no_argument,       NULL, 'h'           },
344                 { "version",      no_argument,       NULL, ARG_VERSION   },
345                 { "listen",       required_argument, NULL, 'l'           },
346                 { "accept",       no_argument,       NULL, 'a'           },
347                 { "environment",  required_argument, NULL, 'E'           },
348                 {}
349         };
350
351         int c;
352
353         assert(argc >= 0);
354         assert(argv);
355
356         while ((c = getopt_long(argc, argv, "+hl:saE:", options, NULL)) >= 0)
357                 switch(c) {
358                 case 'h':
359                         return help();
360
361                 case ARG_VERSION:
362                         puts(PACKAGE_STRING);
363                         puts(SYSTEMD_FEATURES);
364                         return 0 /* done */;
365
366                 case 'l': {
367                         int r = strv_extend(&arg_listen, optarg);
368                         if (r < 0)
369                                 return r;
370
371                         break;
372                 }
373
374                 case 'a':
375                         arg_accept = true;
376                         break;
377
378                 case 'E': {
379                         int r = strv_extend(&arg_environ, optarg);
380                         if (r < 0)
381                                 return r;
382
383                         break;
384                 }
385
386                 case '?':
387                         return -EINVAL;
388
389                 default:
390                         assert_not_reached("Unhandled option");
391                 }
392
393         if (optind == argc) {
394                 log_error("Usage: %s [OPTION...] PROGRAM [OPTION...]",
395                           program_invocation_short_name);
396                 return -EINVAL;
397         }
398
399         arg_args = argv + optind;
400
401         return 1 /* work to do */;
402 }
403
404 int main(int argc, char **argv, char **envp) {
405         int r, n;
406         int epoll_fd = -1;
407
408         log_parse_environment();
409         log_open();
410
411         r = parse_argv(argc, argv);
412         if (r <= 0)
413                 return r == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
414
415         r = install_chld_handler();
416         if (r < 0)
417                 return EXIT_FAILURE;
418
419         n = open_sockets(&epoll_fd, arg_accept);
420         if (n < 0)
421                 return EXIT_FAILURE;
422
423         for (;;) {
424                 struct epoll_event event;
425
426                 r = epoll_wait(epoll_fd, &event, 1, -1);
427                 if (r < 0) {
428                         if (errno == EINTR)
429                                 continue;
430
431                         log_error("epoll_wait() failed: %m");
432                         return EXIT_FAILURE;
433                 }
434
435                 log_info("Communication attempt on fd:%d", event.data.fd);
436                 if (arg_accept) {
437                         r = do_accept(argv[optind], argv + optind, envp,
438                                       event.data.fd);
439                         if (r < 0)
440                                 return EXIT_FAILURE;
441                 } else
442                         break;
443         }
444
445         launch(argv[optind], argv + optind, envp, n);
446
447         return EXIT_SUCCESS;
448 }