chiark / gitweb /
Use new mLib function annotations.
[rsync-backup] / rfreezefs.c
1 /* -*-c-*-
2  *
3  * Freeze a file system under remote control
4  *
5  * (c) 2012 Mark Wooding
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of the `rsync-backup' program.
11  *
12  * rsync-backup is free software; you can redistribute it and/or modify
13  * it under the terms of the GNU General Public License as published by
14  * the Free Software Foundation; either version 2 of the License, or
15  * (at your option) any later version.
16  *
17  * rsync-backup is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU General Public License for more details.
21  *
22  * You should have received a copy of the GNU General Public License along
23  * with rsync-backup; if not, write to the Free Software Foundation,
24  * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
25  */
26
27 /*----- Header files ------------------------------------------------------*/
28
29 #include <assert.h>
30 #include <errno.h>
31 #include <limits.h>
32 #include <signal.h>
33 #include <stdarg.h>
34 #include <stdio.h>
35 #include <string.h>
36 #include <stdlib.h>
37 #include <time.h>
38
39 #include <sys/types.h>
40 #include <sys/time.h>
41 #include <sys/select.h>
42 #include <unistd.h>
43 #include <fcntl.h>
44 #include <sys/ioctl.h>
45
46 #include <linux/fs.h>
47
48 #include <sys/socket.h>
49 #include <arpa/inet.h>
50 #include <netinet/in.h>
51 #include <netdb.h>
52
53 #include <mLib/alloc.h>
54 #include <mLib/dstr.h>
55 #include <mLib/base64.h>
56 #include <mLib/fdflags.h>
57 #include <mLib/macros.h>
58 #include <mLib/mdwopt.h>
59 #include <mLib/quis.h>
60 #include <mLib/report.h>
61 #include <mLib/sub.h>
62 #include <mLib/tv.h>
63
64 /*----- Magic constants ---------------------------------------------------*/
65
66 #define COOKIESZ 16                     /* Size of authentication cookie */
67 #define TO_CONNECT 30                   /* Timeout for incoming connection */
68 #define TO_KEEPALIVE 60                 /* Timeout between keepalives */
69
70 /*----- Utility functions -------------------------------------------------*/
71
72 static int getuint(const char *p, const char *q)
73 {
74   unsigned long i;
75   int e = errno;
76   char *qq;
77
78   if (!q) q = p + strlen(p);
79   errno = 0;
80   i = strtoul(p, &qq, 0);
81   if (errno || qq < q || i > INT_MAX)
82     die(1, "invalid integer `%s'", p);
83   errno = e;
84   return ((int)i);
85 }
86
87 #ifdef DEBUG
88 #  define D(x) x
89 #else
90 #  define D(x)
91 #endif
92
93 /*----- Token management --------------------------------------------------*/
94
95 struct token {
96   const char *label;
97   char tok[(COOKIESZ + 2)*4/3 + 1];
98 };
99
100 #define TOKENS(_)                                                       \
101   _(FREEZE)                                                             \
102   _(FROZEN)                                                             \
103   _(KEEPALIVE)                                                          \
104   _(THAW)                                                               \
105   _(THAWED)
106
107 enum {
108 #define ENUM(tok) T_##tok,
109   TOKENS(ENUM)
110 #undef ENUM
111   T_LIMIT
112 };
113
114 enum {
115 #define MASK(tok) TF_##tok = 1u << T_##tok,
116   TOKENS(MASK)
117 #undef ENUM
118   TF_ALL = (1u << T_LIMIT) - 1u
119 };
120
121 static struct token toktab[] = {
122 #define INIT(tok) { #tok },
123   TOKENS(INIT)
124 #undef INIT
125   { 0 }
126 };
127
128 static void inittoks(void)
129 {
130   static struct token *t, *tt;
131   unsigned char buf[COOKIESZ];
132   int fd;
133   ssize_t n;
134   base64_ctx bc;
135   dstr d = DSTR_INIT;
136
137   if ((fd = open("/dev/urandom", O_RDONLY)) < 0)
138     die(2, "open (urandom): %s", strerror(errno));
139
140   for (t = toktab; t->label; t++) {
141   again:
142     n = read(fd, buf, COOKIESZ);
143     if (n < 0) die(2, "read (urandom): %s", strerror(errno));
144     else if (n < COOKIESZ) die(2, "read (urandom): short read");
145     base64_init(&bc);
146     base64_encode(&bc, buf, COOKIESZ, &d);
147     base64_encode(&bc, 0, 0, &d);
148     dstr_putz(&d);
149
150     for (tt = toktab; tt < t; tt++) {
151       if (strcmp(d.buf, tt->tok) == 0)
152         goto again;
153     }
154
155     assert(d.len < sizeof(t->tok));
156     memcpy(t->tok, d.buf, d.len + 1);
157     dstr_reset(&d);
158   }
159 }
160
161 struct tokmatch {
162   unsigned tf;                          /* Possible token matches */
163   size_t o;                             /* Offset into token string */
164   unsigned f;                           /* Flags */
165 #define TMF_CR 1u                       /*   Seen trailing carriage-return */
166 };
167
168 static void tokmatch_init(struct tokmatch *tm)
169   { tm->tf = TF_ALL; tm->o = 0; tm->f = 0; }
170
171 static int tokmatch_update(struct tokmatch *tm, int ch)
172 {
173   const struct token *t;
174   unsigned tf;
175
176   switch (ch) {
177     case '\n':
178       for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
179         if ((tm->tf & tf) && !t->tok[tm->o])
180           return (tf);
181       }
182       return (-1);
183     case '\r':
184       for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
185         if ((tm->tf & tf) && !t->tok[tm->o] && !(tm->f & TMF_CR))
186           tm->f |= TMF_CR;
187         else
188           tm->tf &= ~tf;
189       }
190       break;
191     default:
192       for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
193         if ((tm->tf & tf) && ch != t->tok[tm->o])
194           tm->tf &= ~tf;
195       }
196       tm->o++;
197       break;
198   }
199   return (0);
200 }
201
202 static int writetok(unsigned i, int fd)
203 {
204   static const char nl = '\n';
205   const struct token *t = &toktab[i];
206   size_t n = strlen(t->tok);
207
208   errno = EIO;
209   if (write(fd, t->tok, n) < n ||
210       write(fd, &nl, 1) < 1)
211     return (-1);
212   return (0);
213 }
214
215 /*----- Data structures ---------------------------------------------------*/
216
217 struct client {
218   struct client *next;                  /* Links in the client chain */
219   int fd;                               /* File descriptor for socket */
220   struct tokmatch tm;                   /* Token matching context */
221 };
222
223 /*----- Static variables --------------------------------------------------*/
224
225 static int *fs;                         /* File descriptors for targets */
226 static char **fsname;                   /* File system names */
227 static size_t nfs;                      /* Number of descriptors */
228
229 /*----- Cleanup -----------------------------------------------------------*/
230
231 #define EOM ((char *)0)
232 static void EXECL_LIKE(0) emerg(const char *msg,...)
233 {
234   va_list ap;
235
236 #define MSG(m)                                                          \
237   do { const char *m_ = m; if (write(2, m_, strlen(m_))); } while (0)
238
239   va_start(ap, msg);
240   MSG(QUIS); MSG(": ");
241   do {
242     MSG(msg);
243     msg = va_arg(ap, const char *);
244   } while (msg != EOM);
245   MSG("\n");
246
247 #undef MSG
248 }
249
250 static void partial_cleanup(size_t n)
251 {
252   int i;
253   int bad = 0;
254
255   for (i = 0; i < nfs; i++) {
256     if (fs[i] == -1)
257       emerg("not really thawing ", fsname[i], EOM);
258     else if (fs[i] != -2) {
259       if (ioctl(fs[i], FITHAW, 0)) {
260         emerg("VERY BAD!  failed to thaw ",
261               fsname[i], ": ", strerror(errno), EOM);
262         bad = 1;
263       }
264       close(fs[i]);
265     }
266     fs[i] = -2;
267   }
268   if (bad) _exit(112);
269 }
270
271 static void cleanup(void) { partial_cleanup(nfs); }
272
273 static int sigcatch[] = {
274   SIGINT, SIGQUIT, SIGTERM, SIGHUP, SIGALRM,
275   SIGILL, SIGSEGV, SIGBUS, SIGFPE, SIGABRT
276 };
277
278 static void NORETURN sigmumble(int sig)
279 {
280   sigset_t ss;
281
282   cleanup();
283   emerg(strsignal(sig), EOM);
284
285   signal(sig, SIG_DFL);
286   sigemptyset(&ss); sigaddset(&ss, sig);
287   sigprocmask(SIG_UNBLOCK, &ss, 0);
288   raise(sig);
289   _exit(4);
290 }
291
292 /*----- Help functions ----------------------------------------------------*/
293
294 static void version(FILE *fp)
295   { pquis(fp, "$, " PACKAGE " version " VERSION "\n"); }
296 static void usage(FILE *fp)
297   { pquis(fp, "Usage: $ [-n] [-a ADDR] [-p LOPORT[-HIPORT]] FILSYS ...\n"); }
298
299 static void help(FILE *fp)
300 {
301   version(fp); putc('\n', fp);
302   usage(fp);
303   fputs("\n\
304 Freezes a filesystem temporarily, with some measure of safety.\n\
305 \n\
306 The program listens for connections on a TCP port, and prints a line\n\
307 \n\
308         PORT COOKIE\n\
309 \n\
310 to standard output.  You must connect to this PORT and send the COOKIE\n\
311 followed by a newline within a short period of time.  The filesystems\n\
312 will then be frozen, and `OK' written to the connection.  In order to\n\
313 keep the file system frozen, you must keep the connection open, and\n\
314 feed data into it.  If the connection closes, or no data is received\n\
315 within a set period of time, or the program receives one of a variety\n\
316 of signals or otherwise becomes unhappy, the filesystems are thawed again.\n\
317 \n\
318 Options:\n\
319 \n\
320 -h, --help                      Print this help text.\n\
321 -v, --version                   Print the program version number.\n\
322 -u, --usage                     Print a short usage message.\n\
323 \n\
324 -a, --address=ADDR              Listen only on ADDR.\n\
325 -n, --not-really                Don't really freeze or thaw filesystems.\n\
326 -p, --port-range=LO[-HI]        Select a port number between LO and HI.\n\
327                                   If HI is omitted, choose only LO.\n\
328 ", fp);
329 }
330
331 /*----- Main program ------------------------------------------------------*/
332
333 int main(int argc, char *argv[])
334 {
335   char buf[256];
336   int loport = -1, hiport = -1;
337   int sk, fd, maxfd;
338   struct sockaddr_in sin;
339   socklen_t sasz;
340   struct hostent *h;
341   const char *p, *q;
342   struct timeval now, when, delta;
343   struct client *clients = 0, *c, **cc;
344   const struct token *t;
345   struct tokmatch tm;
346   fd_set fdin;
347   int i;
348   ssize_t n;
349   unsigned f = 0;
350 #define f_bogus 0x01u
351 #define f_notreally 0x02u
352
353   ego(argv[0]);
354   sub_init();
355
356   /* --- Partially initialize the socket address --- */
357
358   sin.sin_family = AF_INET;
359   sin.sin_addr.s_addr = INADDR_ANY;
360   sin.sin_port = 0;
361
362   /* --- Parse the command line --- */
363
364   for (;;) {
365     static struct option opts[] = {
366       { "help",         0,              0,      'h' },
367       { "version",      0,              0,      'v' },
368       { "usage",        0,              0,      'u' },
369       { "address",      OPTF_ARGREQ,    0,      'a' },
370       { "not-really",   0,              0,      'n' },
371       { "port-range",   OPTF_ARGREQ,    0,      'p' },
372       { 0,              0,              0,      0 }
373     };
374
375     if ((i = mdwopt(argc, argv, "hvua:np:", opts, 0, 0, 0)) < 0) break;
376     switch (i) {
377       case 'h': help(stdout); exit(0);
378       case 'v': version(stdout); exit(0);
379       case 'u': usage(stdout); exit(0);
380       case 'a':
381         if ((h = gethostbyname(optarg)) == 0) {
382           die(1, "failed to resolve address `%s': %s",
383               optarg, hstrerror(h_errno));
384         }
385         if (h->h_addrtype != AF_INET)
386           die(1, "unexpected address type resolving `%s'", optarg);
387         assert(h->h_length == sizeof(sin.sin_addr));
388         memcpy(&sin.sin_addr, h->h_addr, sizeof(sin.sin_addr));
389         break;
390       case 'n': f |= f_notreally; break;
391       case 'p':
392         if ((p = strchr(optarg, '-')) == 0)
393           loport = hiport = getuint(optarg, 0);
394         else {
395           loport = getuint(optarg, p);
396           hiport = getuint(p + 1, 0);
397         }
398         break;
399       default: f |= f_bogus; break;
400     }
401   }
402   if (f & f_bogus) { usage(stderr); exit(1); }
403   if (optind >= argc) { usage(stderr); exit(1); }
404
405   /* --- Open the file systems --- */
406
407   nfs = argc - optind;
408   fsname = &argv[optind];
409   fs = xmalloc(nfs*sizeof(*fs));
410   for (i = 0; i < nfs; i++) {
411     if ((fs[i] = open(fsname[i], O_RDONLY)) < 0)
412       die(2, "open (%s): %s", fsname[i], strerror(errno));
413   }
414
415   if (f & f_notreally) {
416     for (i = 0; i < nfs; i++) {
417       close(fs[i]);
418       fs[i] = -1;
419     }
420   }
421
422   /* --- Generate random tokens --- */
423
424   inittoks();
425
426   /* --- Create the listening socket --- */
427
428   if ((sk = socket(PF_INET, SOCK_STREAM, 0)) < 0)
429     die(2, "socket: %s", strerror(errno));
430   i = 1;
431   if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)))
432     die(2, "setsockopt (reuseaddr): %s", strerror(errno));
433   if (fdflags(sk, O_NONBLOCK, O_NONBLOCK, FD_CLOEXEC, FD_CLOEXEC))
434     die(2, "fdflags: %s", strerror(errno));
435   if (loport < 0 || loport == hiport) {
436     if (loport >= 0) sin.sin_port = htons(loport);
437     if (bind(sk, (struct sockaddr *)&sin, sizeof(sin)))
438       die(2, "bind: %s", strerror(errno));
439   } else if (hiport != loport) {
440     for (i = loport; i <= hiport; i++) {
441       sin.sin_port = htons(i);
442       if (bind(sk, (struct sockaddr *)&sin, sizeof(sin)) >= 0) break;
443       else if (errno != EADDRINUSE)
444         die(2, "bind: %s", strerror(errno));
445     }
446     if (i > hiport) die(2, "bind: all ports in use");
447   }
448   if (listen(sk, 5)) die(2, "listen: %s", strerror(errno));
449
450   /* --- Tell the caller how to connect to us, and start the timer --- */
451
452   sasz = sizeof(sin);
453   if (getsockname(sk, (struct sockaddr *)&sin, &sasz))
454     die(2, "getsockname (listen): %s", strerror(errno));
455   printf("PORT %d\n", ntohs(sin.sin_port));
456   for (t = toktab; t->label; t++)
457     printf("TOKEN %s %s\n", t->label, t->tok);
458   printf("READY\n");
459   if (fflush(stdout) || ferror(stdout))
460     die(2, "write (stdout, rubric): %s", strerror(errno));
461   gettimeofday(&now, 0); TV_ADDL(&when, &now, TO_CONNECT, 0);
462
463   /* --- Collect incoming connections, and check for the cookie --- *
464    *
465    * This is the tricky part.
466    */
467
468   for (;;) {
469     FD_ZERO(&fdin);
470     FD_SET(sk, &fdin);
471     maxfd = sk;
472     for (c = clients; c; c = c->next) {
473       FD_SET(c->fd, &fdin);
474       if (c->fd > maxfd) maxfd = c->fd;
475     }
476     TV_SUB(&delta, &when, &now);
477     if (select(maxfd + 1, &fdin, 0, 0, &delta) < 0)
478       die(2, "select (accept): %s", strerror(errno));
479     gettimeofday(&now, 0);
480
481     if (TV_CMP(&now, >=, &when)) die(3, "timeout (accept)");
482
483     if (FD_ISSET(sk, &fdin)) {
484       sasz = sizeof(sin);
485       fd = accept(sk, (struct sockaddr *)&sin, &sasz);
486       if (fd >= 0) {
487         if (fdflags(fd, O_NONBLOCK, O_NONBLOCK, FD_CLOEXEC, FD_CLOEXEC) < 0)
488           die(2, "fdflags: %s", strerror(errno));
489         c = CREATE(struct client);
490         c->next = clients; c->fd = fd; tokmatch_init(&c->tm);
491         clients = c;
492       }
493 #ifdef DEBUG
494       else if (errno != EAGAIN)
495         moan("accept: %s", strerror(errno));
496 #endif
497     }
498
499     for (cc = &clients; *cc;) {
500       c = *cc;
501       if (!FD_ISSET(c->fd, &fdin)) goto next_client;
502       n = read(c->fd, buf, sizeof(buf));
503       if (!n) goto disconn;
504       else if (n < 0) {
505         if (errno == EAGAIN) goto next_client;
506         D( moan("read (client; auth): %s", strerror(errno)); )
507         goto disconn;
508       } else {
509         for (p = buf, q = p + n; p < q; p++) {
510           switch (tokmatch_update(&c->tm, *p)) {
511             case 0: break;
512             case TF_FREEZE: goto connected;
513             default:
514               D( moan("bad token from client"); )
515               goto disconn;
516           }
517         }
518       }
519
520     next_client:
521       cc = &c->next;
522       continue;
523
524     disconn:
525       close(c->fd);
526       *cc = c->next;
527       DESTROY(c);
528       continue;
529     }
530   }
531
532 connected:
533   close(sk); sk = c->fd;
534   while (clients) {
535     if (clients->fd != sk) close(clients->fd);
536     c = clients->next;
537     DESTROY(clients);
538     clients = c;
539   }
540
541   /* --- Establish signal handlers --- *
542    *
543    * Hopefully this will prevent bad things happening if we have an accident.
544    */
545
546   for (i = 0; i < sizeof(sigcatch)/sizeof(sigcatch[0]); i++) {
547     if (signal(sigcatch[i], sigmumble) == SIG_ERR)
548       die(2, "signal (%d): %s", i, strerror(errno));
549   }
550   atexit(cleanup);
551
552   /* --- Prevent the OOM killer from clobbering us --- */
553
554   if ((fd = open("/proc/self/oom_adj", O_WRONLY)) < 0 ||
555       write(fd, "-17\n", 4) < 4 ||
556       close(fd))
557     die(2, "set oom_adj: %s", strerror(errno));
558
559   /* --- Actually freeze the filesystem --- */
560
561   for (i = 0; i < nfs; i++) {
562     if (fs[i] == -1)
563       moan("not really freezing %s", fsname[i]);
564     else {
565       if (ioctl(fs[i], FIFREEZE, 0) < 0) {
566         partial_cleanup(i);
567         die(2, "ioctl (freeze %s): %s", fsname[i], strerror(errno));
568       }
569     }
570   }
571   if (writetok(T_FROZEN, sk)) {
572     cleanup();
573     die(2, "write (frozen): %s", strerror(errno));
574   }
575
576   /* --- Now wait for the other end to detach --- */
577
578   tokmatch_init(&tm);
579   TV_ADDL(&when, &now, TO_KEEPALIVE, 0);
580   for (p++; p < q; p++) {
581     switch (tokmatch_update(&tm, *p)) {
582       case 0: break;
583       case TF_KEEPALIVE: tokmatch_init(&tm); break;
584       case TF_THAW: goto done;
585       default: cleanup(); die(3, "unknown token (keepalive)");
586     }
587   }
588   for (;;) {
589     FD_ZERO(&fdin);
590     FD_SET(sk, &fdin);
591     TV_SUB(&delta, &when, &now);
592     if (select(sk + 1, &fdin, 0, 0, &delta) < 0) {
593       cleanup();
594       die(2, "select (keepalive): %s", strerror(errno));
595     }
596
597     gettimeofday(&now, 0);
598     if (TV_CMP(&now, >, &when)) {
599       cleanup(); die(3, "timeout (keepalive)");
600     }
601     if (FD_ISSET(sk, &fdin)) {
602       n = read(sk, buf, sizeof(buf));
603       if (!n) { cleanup(); die(3, "end-of-file (keepalive)"); }
604       else if (n < 0) {
605         if (errno == EAGAIN) ;
606         else {
607           cleanup();
608           die(2, "read (client, keepalive): %s", strerror(errno));
609         }
610       } else {
611         for (p = buf, q = p + n; p < q; p++) {
612           switch (tokmatch_update(&tm, *p)) {
613             case 0: break;
614             case TF_KEEPALIVE:
615               TV_ADDL(&when, &now, TO_KEEPALIVE, 0);
616               tokmatch_init(&tm);
617               break;
618             case TF_THAW:
619               goto done;
620             default:
621               cleanup();
622               die(3, "unknown token (keepalive)");
623           }
624         }
625       }
626     }
627   }
628
629 done:
630   cleanup();
631   if (writetok(T_THAWED, sk))
632     die(2, "write (thaw): %s", strerror(errno));
633   close(sk);
634   return (0);
635 }
636
637 /*----- That's all, folks -------------------------------------------------*/