chiark / gitweb /
move _cleanup_ attribute in front of the type
[elogind.git] / src / activate / activate.c
1 /*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/
2
3 /***
4   This file is part of systemd.
5
6   Copyright 2013 Zbigniew JÄ™drzejewski-Szmek
7
8   systemd is free software; you can redistribute it and/or modify it
9   under the terms of the GNU Lesser General Public License as published by
10   the Free Software Foundation; either version 2.1 of the License, or
11   (at your option) any later version.
12
13   systemd is distributed in the hope that it will be useful, but
14   WITHOUT ANY WARRANTY; without even the implied warranty of
15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
16   Lesser General Public License for more details.
17
18   You should have received a copy of the GNU Lesser General Public License
19   along with systemd; If not, see <http://www.gnu.org/licenses/>.
20 ***/
21
22 #include <unistd.h>
23 #include <fcntl.h>
24 #include <sys/epoll.h>
25 #include <sys/prctl.h>
26 #include <sys/socket.h>
27 #include <sys/wait.h>
28 #include <getopt.h>
29
30 #include <systemd/sd-daemon.h>
31
32 #include "socket-util.h"
33 #include "build.h"
34 #include "log.h"
35 #include "strv.h"
36 #include "macro.h"
37
38 static char** arg_listen = NULL;
39 static bool arg_accept = false;
40 static char** arg_args = NULL;
41 static char** arg_environ = NULL;
42
43 static int add_epoll(int epoll_fd, int fd) {
44         int r;
45         struct epoll_event ev = {EPOLLIN};
46         ev.data.fd = fd;
47
48         assert(epoll_fd >= 0);
49         assert(fd >= 0);
50
51         r = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, fd, &ev);
52         if (r < 0)
53                 log_error("Failed to add event on epoll fd:%d for fd:%d: %s",
54                           epoll_fd, fd, strerror(-r));
55         return r;
56 }
57
58 static int set_nocloexec(int fd) {
59         int flags;
60
61         flags = fcntl(fd, F_GETFD);
62         if (flags < 0) {
63                 log_error("Querying flags for fd:%d: %m", fd);
64                 return -errno;
65         }
66
67         if (!(flags & FD_CLOEXEC))
68                 return 0;
69
70         if (fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC) < 0) {
71                 log_error("Settings flags for fd:%d: %m", fd);
72                 return -errno;
73         }
74
75         return 0;
76 }
77
78 static int print_socket(const char* desc, int fd) {
79         int r;
80         SocketAddress addr = {
81                 .size = sizeof(union sockaddr_union),
82                 .type = SOCK_STREAM,
83         };
84         int family;
85
86         r = getsockname(fd, &addr.sockaddr.sa, &addr.size);
87         if (r < 0) {
88                 log_warning("Failed to query socket on fd:%d: %m", fd);
89                 return 0;
90         }
91
92         family = socket_address_family(&addr);
93         switch(family) {
94         case AF_INET:
95         case AF_INET6: {
96                 char* _cleanup_free_ a = NULL;
97                 r = socket_address_print(&addr, &a);
98                 if (r < 0)
99                         log_warning("socket_address_print(): %s", strerror(-r));
100                 else
101                         log_info("%s %s address %s",
102                                  desc,
103                                  family == AF_INET ? "IP" : "IPv6",
104                                  a);
105                 break;
106         }
107         default:
108                 log_warning("Connection with unknown family %d", family);
109         }
110
111         return 0;
112 }
113
114 static int open_sockets(int *epoll_fd, bool accept) {
115         int n, fd;
116         int count = 0;
117         char **address;
118
119         n = sd_listen_fds(true);
120         if (n < 0) {
121                 log_error("Failed to read listening file descriptors from environment: %s",
122                           strerror(-n));
123                 return n;
124         }
125         log_info("Received %d descriptors", n);
126
127         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) {
128                 log_debug("Received descriptor fd:%d", fd);
129                 print_socket("Listening on", fd);
130
131                 if (!arg_accept) {
132                         int r = set_nocloexec(fd);
133                         if (r < 0)
134                                 return r;
135                 }
136
137                 count ++;
138         }
139
140         /** Note: we leak some fd's on error here. I doesn't matter
141          *  much, since the program will exit immediately anyway, but
142          *  would be a pain to fix.
143          */
144
145         STRV_FOREACH(address, arg_listen) {
146                 log_info("Opening address %s", *address);
147
148                 fd = make_socket_fd(*address, SOCK_STREAM | (arg_accept*SOCK_CLOEXEC));
149                 if (fd < 0) {
150                         log_error("Failed to open '%s': %s", *address, strerror(-fd));
151                         return fd;
152                 }
153
154                 count ++;
155         }
156
157         *epoll_fd = epoll_create1(EPOLL_CLOEXEC);
158         if (*epoll_fd < 0) {
159                 log_error("Failed to create epoll object: %m");
160                 return -errno;
161         }
162
163         for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + count; fd++) {
164                 int r = add_epoll(*epoll_fd, fd);
165                 if (r < 0)
166                         return r;
167         }
168
169         return count;
170 }
171
172 static int launch(char* name, char **argv, char **env, int fds) {
173         unsigned n_env = 0, length;
174         _cleanup_strv_free_ char **envp = NULL;
175         char **s;
176         static const char* tocopy[] = {"TERM=", "PATH=", "USER=", "HOME="};
177         _cleanup_free_ char *tmp = NULL;
178         unsigned i;
179
180         length = strv_length(arg_environ);
181         /* PATH, TERM, HOME, USER, LISTEN_FDS, LISTEN_PID, NULL */
182         envp = new(char *, length + 7);
183
184         STRV_FOREACH(s, arg_environ) {
185                 if (strchr(*s, '='))
186                         envp[n_env++] = *s;
187                 else {
188                         _cleanup_free_ char *p = strappend(*s, "=");
189                         if (!p)
190                                 return log_oom();
191                         envp[n_env] = strv_find_prefix(env, p);
192                         if (envp[n_env])
193                                 n_env ++;
194                 }
195         }
196
197         for (i = 0; i < ELEMENTSOF(tocopy); i++) {
198                 envp[n_env] = strv_find_prefix(env, tocopy[i]);
199                 if (envp[n_env])
200                         n_env ++;
201         }
202
203         if ((asprintf((char**)(envp + n_env++), "LISTEN_FDS=%d", fds) < 0) ||
204             (asprintf((char**)(envp + n_env++), "LISTEN_PID=%d", getpid()) < 0))
205                 return log_oom();
206
207         tmp = strv_join(argv, " ");
208         if (!tmp)
209                 return log_oom();
210
211         log_info("Execing %s (%s)", name, tmp);
212         execvpe(name, argv, envp);
213         log_error("Failed to execp %s (%s): %m", name, tmp);
214         return -errno;
215 }
216
217 static int launch1(const char* child, char** argv, char **env, int fd) {
218         pid_t parent_pid, child_pid;
219         int r;
220
221         _cleanup_free_ char *tmp = NULL;
222         tmp = strv_join(argv, " ");
223         if (!tmp)
224                 return log_oom();
225
226         parent_pid = getpid();
227
228         child_pid = fork();
229         if (child_pid < 0) {
230                 log_error("Failed to fork: %m");
231                 return -errno;
232         }
233
234         /* In the child */
235         if (child_pid == 0) {
236                 r = dup2(fd, STDIN_FILENO);
237                 if (r < 0) {
238                         log_error("Failed to dup connection to stdin: %m");
239                         _exit(EXIT_FAILURE);
240                 }
241
242                 r = dup2(fd, STDOUT_FILENO);
243                 if (r < 0) {
244                         log_error("Failed to dup connection to stdout: %m");
245                         _exit(EXIT_FAILURE);
246                 }
247
248                 r = close(fd);
249                 if (r < 0) {
250                         log_error("Failed to close dupped connection: %m");
251                         _exit(EXIT_FAILURE);
252                 }
253
254                 /* Make sure the child goes away when the parent dies */
255                 if (prctl(PR_SET_PDEATHSIG, SIGTERM) < 0)
256                         _exit(EXIT_FAILURE);
257
258                 /* Check whether our parent died before we were able
259                  * to set the death signal */
260                 if (getppid() != parent_pid)
261                         _exit(EXIT_SUCCESS);
262
263                 execvp(child, argv);
264                 log_error("Failed to exec child %s: %m", child);
265                 _exit(EXIT_FAILURE);
266         }
267
268         log_info("Spawned %s (%s) as PID %d", child, tmp, child_pid);
269
270         return 0;
271 }
272
273 static int do_accept(const char* name, char **argv, char **envp, int fd) {
274         SocketAddress addr = {
275                 .size = sizeof(union sockaddr_union),
276                 .type = SOCK_STREAM,
277         };
278         int fd2, r;
279
280         fd2 = accept(fd, &addr.sockaddr.sa, &addr.size);
281         if (fd2 < 0) {
282                 log_error("Failed to accept connection on fd:%d: %m", fd);
283                 return fd2;
284         }
285
286         print_socket("Connection from", fd2);
287
288         r = launch1(name, argv, envp, fd2);
289         return r;
290 }
291
292 /* SIGCHLD handler. */
293 static void sigchld_hdl(int sig, siginfo_t *t, void *data)
294 {
295         log_info("Child %d died with code %d", t->si_pid, t->si_status);
296         /* Wait for a dead child. */
297         waitpid(t->si_pid, NULL, 0);
298 }
299
300 static int install_chld_handler(void) {
301         int r;
302         struct sigaction act;
303         zero(act);
304         act.sa_flags = SA_SIGINFO;
305         act.sa_sigaction = sigchld_hdl;
306
307         r = sigaction(SIGCHLD, &act, 0);
308         if (r < 0)
309                 log_error("Failed to install SIGCHLD handler: %m");
310         return r;
311 }
312
313 static int help(void) {
314         printf("%s [OPTIONS...]\n\n"
315                "Listen on sockets and launch child on connection.\n\n"
316                "Options:\n"
317                "  -l --listen=ADDR     Listen for raw connections at ADDR\n"
318                "  -a --accept          Spawn separate child for each connection\n"
319                "  -h --help            Show this help and exit\n"
320                "  --version            Print version string and exit\n"
321                "\n"
322                "Note: file descriptors from sd_listen_fds() will be passed through.\n"
323                , program_invocation_short_name
324                );
325
326         return 0;
327 }
328
329 static int parse_argv(int argc, char *argv[]) {
330         enum {
331                 ARG_VERSION = 0x100,
332         };
333
334         static const struct option options[] = {
335                 { "help",         no_argument,       NULL, 'h'           },
336                 { "version",      no_argument,       NULL, ARG_VERSION   },
337                 { "listen",       required_argument, NULL, 'l'           },
338                 { "accept",       no_argument,       NULL, 'a'           },
339                 { "environment",  required_argument, NULL, 'E'           },
340                 { NULL,           0,                 NULL, 0             }
341         };
342
343         int c;
344
345         assert(argc >= 0);
346         assert(argv);
347
348         while ((c = getopt_long(argc, argv, "+hl:saE:", options, NULL)) >= 0)
349                 switch(c) {
350                 case 'h':
351                         help();
352                         return 0 /* done */;
353
354                 case ARG_VERSION:
355                         puts(PACKAGE_STRING);
356                         puts(SYSTEMD_FEATURES);
357                         return 0 /* done */;
358
359                 case 'l': {
360                         int r = strv_extend(&arg_listen, optarg);
361                         if (r < 0)
362                                 return r;
363
364                         break;
365                 }
366
367                 case 'a':
368                         arg_accept = true;
369                         break;
370
371                 case 'E': {
372                         int r = strv_extend(&arg_environ, optarg);
373                         if (r < 0)
374                                 return r;
375
376                         break;
377                 }
378
379                 case '?':
380                         return -EINVAL;
381
382                 default:
383                         log_error("Unknown option code %c", c);
384                         return -EINVAL;
385                 }
386
387         if (optind == argc) {
388                 log_error("Usage: %s [OPTION...] PROGRAM [OPTION...]",
389                           program_invocation_short_name);
390                 return -EINVAL;
391         }
392
393         arg_args = argv + optind;
394
395         return 1 /* work to do */;
396 }
397
398 int main(int argc, char **argv, char **envp) {
399         int r, n;
400         int epoll_fd = -1;
401
402         log_set_max_level(LOG_DEBUG);
403         log_show_color(true);
404         log_parse_environment();
405
406         r = parse_argv(argc, argv);
407         if (r <= 0)
408                 return r == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
409
410         r = install_chld_handler();
411         if (r < 0)
412                 return EXIT_FAILURE;
413
414         n = open_sockets(&epoll_fd, arg_accept);
415         if (n < 0)
416                 return EXIT_FAILURE;
417
418         while (true) {
419                 struct epoll_event event;
420
421                 r = epoll_wait(epoll_fd, &event, 1, -1);
422                 if (r < 0) {
423                         if (errno == EINTR)
424                                 continue;
425
426                         log_error("epoll_wait() failed: %m");
427                         return EXIT_FAILURE;
428                 }
429
430                 log_info("Communication attempt on fd:%d", event.data.fd);
431                 if (arg_accept) {
432                         r = do_accept(argv[optind], argv + optind, envp,
433                                       event.data.fd);
434                         if (r < 0)
435                                 return EXIT_FAILURE;
436                 } else
437                         break;
438         }
439
440         launch(argv[optind], argv + optind, envp, n);
441
442         return EXIT_SUCCESS;
443 }