+/* -*-c-*-
+ *
+ * Freeze a file system under remote control
+ *
+ * (c) 2012 Mark Wooding
+ */
+
+/*----- Licensing notice --------------------------------------------------*
+ *
+ * This file is part of the `rsync-backup' program.
+ *
+ * rsync-backup is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * rsync-backup is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along
+ * with rsync-backup; if not, write to the Free Software Foundation,
+ * Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+ */
+
+/*----- Header files ------------------------------------------------------*/
+
+#include <assert.h>
+#include <errno.h>
+#include <limits.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <string.h>
+#include <stdlib.h>
+#include <time.h>
+
+#include <sys/types.h>
+#include <sys/time.h>
+#include <sys/select.h>
+#include <unistd.h>
+#include <fcntl.h>
+#include <sys/ioctl.h>
+
+#include <linux/fs.h>
+
+#include <sys/socket.h>
+#include <arpa/inet.h>
+#include <netinet/in.h>
+#include <netdb.h>
+
+#include <mLib/alloc.h>
+#include <mLib/dstr.h>
+#include <mLib/base64.h>
+#include <mLib/fdflags.h>
+#include <mLib/mdwopt.h>
+#include <mLib/quis.h>
+#include <mLib/report.h>
+#include <mLib/sub.h>
+#include <mLib/tv.h>
+
+/*----- Magic constants ---------------------------------------------------*/
+
+#define COOKIESZ 16 /* Random bytes per token, before base64 encoding */
+#define TO_CONNECT 30 /* Seconds from the rubric until FREEZE must arrive */
+#define TO_KEEPALIVE 60 /* Max seconds between KEEPALIVEs while frozen */
+
+/*----- Utility functions -------------------------------------------------*/
+
+/* --- @getuint@ --- *
+ *
+ * Arguments:	@const char *p@ = pointer to the start of a numeral
+ *		@const char *q@ = pointer just past its end, or null to
+ *			mean `the rest of the string'
+ *
+ * Returns:	The numeral's value.
+ *
+ * Use:		Parses a nonnegative integer (decimal, or octal/hex with the
+ *		usual C prefixes) occupying exactly [p, q).  Dies with status
+ *		1 if the text is empty, has a sign or leading space, has
+ *		trailing junk, or exceeds @INT_MAX@.  Preserves @errno@.
+ */
+static int getuint(const char *p, const char *q)
+{
+  unsigned long i;
+  int e = errno;
+  char *qq;
+
+  if (!q) q = p + strlen(p);
+
+  /* @strtoul@ skips leading whitespace, accepts (and negates with) a sign,
+   * and reports success having consumed nothing at all; so insist on a
+   * leading digit ourselves. */
+  if (p == q || *p < '0' || *p > '9')
+    die(1, "invalid integer `%s'", p);
+
+  errno = 0;
+  i = strtoul(p, &qq, 0);
+  if (errno || qq != q || i > INT_MAX)
+    die(1, "invalid integer `%s'", p);
+  errno = e;
+  return ((int)i);
+}
+
+/* @D(stmts)@ expands to @stmts@ in DEBUG builds and to nothing otherwise;
+ * it wraps debug-only diagnostics such as @moan@ calls. */
+#ifdef DEBUG
+# define D(x) x
+#else
+# define D(x)
+#endif
+
+/*----- Token management --------------------------------------------------*/
+
+/* A protocol token: the name shown in the rubric, and the random string
+ * which actually stands for it on the wire. */
+struct token {
+ const char *label; /* Name, e.g. "FREEZE"; null ends @toktab@ */
+ char tok[(COOKIESZ + 2)*4/3 + 1]; /* NUL-terminated base64 of COOKIESZ random bytes */
+};
+
+/* The protocol tokens, as an X-macro: @TOKENS(_)@ applies @_@ to each
+ * token name in turn. */
+#define TOKENS(_) \
+  _(FREEZE) \
+  _(FROZEN) \
+  _(KEEPALIVE) \
+  _(THAW) \
+  _(THAWED)
+
+/* Token indices into @toktab@ (@T_FREEZE@, ...), and their count. */
+enum {
+#define ENUM(tok) T_##tok,
+  TOKENS(ENUM)
+#undef ENUM
+  T_LIMIT
+};
+
+/* One-bit masks for each token (@TF_FREEZE@, ...), as tracked by
+ * @struct tokmatch@ and returned by @tokmatch_update@; @TF_ALL@ has every
+ * token's bit set. */
+enum {
+#define MASK(tok) TF_##tok = 1u << T_##tok,
+  TOKENS(MASK)
+#undef MASK
+  TF_ALL = (1u << T_LIMIT) - 1u
+};
+
+/* All the tokens, indexed by @T_...@ and terminated by a null @label@.
+ * The @tok@ strings start out empty; @inittoks@ fills them in. */
+static struct token toktab[] = {
+#define INIT(tok) { #tok },
+ TOKENS(INIT)
+#undef INIT
+ { 0 }
+};
+
+/* --- @inittoks@ --- *
+ *
+ * Fills in @tok@ for every entry of @toktab@ with the base64 encoding of
+ * @COOKIESZ@ bytes from /dev/urandom, regenerating any value which
+ * duplicates an earlier token so that all tokens are distinct.  Dies with
+ * status 2 if the random source can't be read.
+ */
+static void inittoks(void)
+{
+  struct token *t, *tt;
+  unsigned char buf[COOKIESZ];
+  int fd;
+  ssize_t n;
+  base64_ctx bc;
+  dstr d = DSTR_INIT;
+
+  if ((fd = open("/dev/urandom", O_RDONLY)) < 0)
+    die(2, "open (urandom): %s", strerror(errno));
+
+  for (t = toktab; t->label; t++) {
+    do {
+      /* Start from an empty buffer on every attempt: the encoder appends,
+       * so a retry would otherwise produce old-token-plus-new-token. */
+      dstr_reset(&d);
+      n = read(fd, buf, COOKIESZ);
+      if (n < 0) die(2, "read (urandom): %s", strerror(errno));
+      else if (n < COOKIESZ) die(2, "read (urandom): short read");
+      base64_init(&bc);
+      base64_encode(&bc, buf, COOKIESZ, &d);
+      base64_encode(&bc, 0, 0, &d);
+      dstr_putz(&d);
+
+      /* Look for a clash with a token we've already chosen. */
+      for (tt = toktab; tt < t && strcmp(d.buf, tt->tok) != 0; tt++)
+        ;
+    } while (tt < t);
+
+    assert(d.len < sizeof(t->tok));
+    memcpy(t->tok, d.buf, d.len + 1);
+  }
+
+  close(fd);
+  dstr_destroy(&d);
+}
+
+/* Incremental state for matching one line of input against the tokens. */
+struct tokmatch {
+ unsigned tf; /* Mask of @TF_...@ bits for tokens still consistent with the input */
+ size_t o; /* Number of token characters matched so far */
+ unsigned f; /* Flags (@TMF_...@) */
+#define TMF_CR 1u /* Seen trailing carriage-return */
+};
+
+/* Reset @tm@ for a fresh line: every token is a candidate again, nothing
+ * has been matched yet, and no carriage return has been seen. */
+static void tokmatch_init(struct tokmatch *tm)
+{
+  tm->f = 0;
+  tm->o = 0;
+  tm->tf = TF_ALL;
+}
+
+/* --- @tokmatch_update@ --- *
+ *
+ * Arguments:	@struct tokmatch *tm@ = token-matching context
+ *		@int ch@ = next character received from the peer
+ *
+ * Returns:	Zero if the line isn't finished yet.  On a newline, the
+ *		@TF_...@ bit of the token the line spelled out, or %$-1$% if
+ *		it didn't match any token.
+ *
+ * Use:		Feeds one character to an incremental matcher which narrows
+ *		down which entries of @toktab@ are still consistent with the
+ *		line so far.  A single carriage return is tolerated directly
+ *		before the newline.  The context isn't reset after a newline;
+ *		the caller reinitializes it with @tokmatch_init@.
+ */
+static int tokmatch_update(struct tokmatch *tm, int ch)
+{
+  const struct token *t;
+  unsigned tf;
+
+  switch (ch) {
+    case '\n':
+      for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
+        if ((tm->tf & tf) && !t->tok[tm->o])
+          return (tf);
+      }
+      return (-1);
+    case '\r':
+      /* Only acceptable once, and only right at the end of a token. */
+      for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
+        if ((tm->tf & tf) && !t->tok[tm->o] && !(tm->f & TMF_CR))
+          tm->f |= TMF_CR;
+        else
+          tm->tf &= ~tf;
+      }
+      break;
+    default:
+      /* A candidate survives only if it still has a character to match and
+       * that character is @ch@.  Testing for the end first matters: a NUL
+       * byte would otherwise `match' the terminator, and @tm->o@ would then
+       * index past the end of @t->tok@ on the following characters. */
+      for (t = toktab, tf = 1; t->label; t++, tf <<= 1) {
+        if ((tm->tf & tf) && (!t->tok[tm->o] || ch != t->tok[tm->o]))
+          tm->tf &= ~tf;
+      }
+      tm->o++;
+      break;
+  }
+  return (0);
+}
+
+/* --- @writetok@ --- *
+ *
+ * Arguments:	@unsigned i@ = index of the token in @toktab@ (a @T_...@)
+ *		@int fd@ = descriptor to write to
+ *
+ * Returns:	Zero on success; %$-1$% on failure, with @errno@ set (to
+ *		@EIO@ for a short write).
+ *
+ * Use:		Sends token @i@ followed by a newline, in a single @write@.
+ */
+static int writetok(unsigned i, int fd)
+{
+  const struct token *t = &toktab[i];
+  char buf[sizeof(t->tok) + 1];
+  size_t n = strlen(t->tok);
+  ssize_t w;
+
+  memcpy(buf, t->tok, n);
+  buf[n++] = '\n';
+
+  /* Check for failure in signed arithmetic first: comparing the raw result
+   * against a @size_t@ would turn %$-1$% into a huge `successful' count. */
+  if ((w = write(fd, buf, n)) < 0) return (-1);
+  if ((size_t)w < n) { errno = EIO; return (-1); }
+  return (0);
+}
+
+/*----- Data structures ---------------------------------------------------*/
+
+/* A connection accepted while waiting for the FREEZE token; clients are
+ * kept on a singly linked list until one sends FREEZE or is dropped. */
+struct client {
+ struct client *next; /* Links in the client chain */
+ int fd; /* File descriptor for socket */
+ struct tokmatch tm; /* Token matching context */
+};
+
+/*----- Static variables --------------------------------------------------*/
+
+/* State shared with the cleanup code, which may run from @atexit@ or from
+ * a signal handler. */
+static int *fs; /* Target descriptors; -1 = `not really' mode, -2 = nothing to thaw */
+static char **fsname; /* File system names */
+static size_t nfs; /* Number of descriptors */
+
+/*----- Cleanup -----------------------------------------------------------*/
+
+/* --- @emerg@ --- *
+ *
+ * Arguments:	@const char *msg@, ... = message fragments, terminated by
+ *			@EOM@ (which must be passed as a pointer, not a bare 0)
+ *
+ * Use:		Writes `PROG: FRAGMENTS...' and a newline to stderr using only
+ *		raw @write@ calls (no stdio), for use from cleanup paths
+ *		including signal handlers.  Write errors are ignored.
+ */
+#define EOM ((char *)0)
+static void emerg(const char *msg, ...)
+{
+  va_list ap;
+
+#define MSG(m) \
+  do { const char *m_ = m; if (write(2, m_, strlen(m_))); } while (0)
+
+  va_start(ap, msg);
+  MSG(QUIS); MSG(": ");
+  do {
+    MSG(msg);
+    msg = va_arg(ap, const char *);
+  } while (msg != EOM);
+  va_end(ap);
+  MSG("\n");
+
+#undef MSG
+}
+
+/* --- @partial_cleanup@ --- *
+ *
+ * Arguments:	@size_t n@ = how many leading entries of @fs@ are frozen
+ *
+ * Use:		Thaws and closes @fs[0]@ to @fs[n - 1]@, and just closes any
+ *		later open descriptors: those file systems were never frozen,
+ *		so @FITHAW@ on them would fail and raise a false alarm.  In
+ *		`not really' mode (entries of %$-1$%) the thaw is only
+ *		reported.  Every entry ends up as %$-2$%, so further calls do
+ *		nothing.  If any thaw fails, exits at once with status 112.
+ */
+static void partial_cleanup(size_t n)
+{
+  size_t i;
+  int bad = 0;
+
+  for (i = 0; i < nfs; i++) {
+    if (fs[i] == -1) {
+      if (i < n) emerg("not really thawing ", fsname[i], EOM);
+    } else if (fs[i] != -2) {
+      if (i < n && ioctl(fs[i], FITHAW, 0)) {
+        emerg("VERY BAD! failed to thaw ",
+              fsname[i], ": ", strerror(errno), EOM);
+        bad = 1;
+      }
+      close(fs[i]);
+    }
+    fs[i] = -2;
+  }
+  if (bad) _exit(112);
+}
+
+/* Treat every target as frozen and release the lot; once done, the
+ * entries are all -2, so calling this again does nothing. */
+static void cleanup(void)
+{
+  partial_cleanup(nfs);
+}
+
+/* Signals which get @sigmumble@ installed as their handler once a client
+ * has connected, so that they thaw the file systems before the process
+ * dies. */
+static int sigcatch[] = {
+ SIGINT, SIGQUIT, SIGTERM, SIGHUP, SIGALRM,
+ SIGILL, SIGSEGV, SIGBUS, SIGFPE, SIGABRT
+};
+
+/* --- @sigmumble@ --- *
+ *
+ * Arguments:	@int sig@ = the signal being handled
+ *
+ * Use:		Handler for the signals in @sigcatch@: thaws everything,
+ *		reports the signal on stderr, then restores the default
+ *		disposition and re-raises the signal so that the process
+ *		dies in the conventional way (falling back to @_exit(4)@).
+ *
+ *		NOTE(review): @strsignal@ (and @strerror@ inside the cleanup)
+ *		aren't async-signal-safe; tolerated here because we're about
+ *		to die anyway.
+ */
+static void sigmumble(int sig)
+{
+  sigset_t ss;
+
+  cleanup();
+
+  /* The fragment list must end in a genuine null pointer: a literal @0@
+   * is passed as an @int@, and reading that back with
+   * @va_arg(ap, const char *)@ is undefined behaviour. */
+  emerg(strsignal(sig), EOM);
+
+  signal(sig, SIG_DFL);
+  sigemptyset(&ss); sigaddset(&ss, sig);
+  sigprocmask(SIG_UNBLOCK, &ss, 0);
+  raise(sig);
+  _exit(4);
+}
+
+/*----- Help functions ----------------------------------------------------*/
+
+/* Print `PROG, PACKAGE version VERSION' to @fp@. */
+static void version(FILE *fp)
+{
+  pquis(fp, "$, " PACKAGE " version " VERSION "\n");
+}
+/* Print the one-line command synopsis to @fp@. */
+static void usage(FILE *fp)
+{
+  pquis(fp, "Usage: $ [-n] [-a ADDR] [-p LOPORT[-HIPORT]] FILSYS ...\n");
+}
+
+/* --- @help@ --- *
+ *
+ * Writes the version, usage line, and a description of the protocol and
+ * options to @fp@.  The description matches what @main@ actually does:
+ * the PORT/TOKEN/READY rubric on stdout, and the FREEZE/FROZEN/KEEPALIVE/
+ * THAW/THAWED exchange on the connection.
+ */
+static void help(FILE *fp)
+{
+  version(fp); putc('\n', fp);
+  usage(fp);
+  fputs("\n\
+Freezes one or more filesystems temporarily, with some measure of safety.\n\
+\n\
+The program listens for connections on a TCP port, and prints\n\
+\n\
+        PORT NUMBER\n\
+        TOKEN NAME VALUE\n\
+        ...\n\
+        READY\n\
+\n\
+to standard output, with one TOKEN line for each of FREEZE, FROZEN,\n\
+KEEPALIVE, THAW and THAWED.  Each VALUE is a random string; sending or\n\
+receiving a token means sending or receiving its VALUE followed by a\n\
+newline.  You must connect to the port and send the FREEZE token within\n\
+a short period of time.  The filesystems will then be frozen, and the\n\
+FROZEN token written to the connection.  In order to keep the\n\
+filesystems frozen, you must keep the connection open, and send the\n\
+KEEPALIVE token regularly.  Send the THAW token to thaw them; the\n\
+program replies with the THAWED token.  If the connection closes, a\n\
+line other than KEEPALIVE or THAW is received, no KEEPALIVE arrives\n\
+within a set period of time, or the program receives one of a variety\n\
+of signals or otherwise becomes unhappy, the filesystems are thawed again.\n\
+\n\
+Options:\n\
+\n\
+-h, --help                  Print this help text.\n\
+-v, --version               Print the program version number.\n\
+-u, --usage                 Print a short usage message.\n\
+\n\
+-a, --address=ADDR          Listen only on ADDR.\n\
+-n, --not-really            Don't really freeze or thaw filesystems.\n\
+-p, --port-range=LO[-HI]    Select a port number between LO and HI.\n\
+                            If HI is omitted, choose only LO.\n\
+", fp);
+}
+
+/*----- Main program ------------------------------------------------------*/
+
+/* --- @main@ --- *
+ *
+ * Overall flow: parse options; open the target file systems; generate the
+ * tokens; listen on a TCP port and print the rubric on stdout; wait at most
+ * @TO_CONNECT@ seconds for some client to send FREEZE; freeze everything
+ * and send FROZEN; keep things frozen while that client sends KEEPALIVE at
+ * least every @TO_KEEPALIVE@ seconds; thaw on THAW (replying THAWED), or on
+ * any error, timeout, end-of-file or caught signal.
+ *
+ * Invariant relied on by @cleanup@ (which may run from @atexit@ or a signal
+ * handler at any point after it's installed): @fs[i]@ holds a descriptor
+ * only once that file system is being frozen; before that it's %$-2$%
+ * (nothing to thaw) or %$-1$% (`not really' mode).
+ */
+int main(int argc, char *argv[])
+{
+  char buf[256];
+  int loport = -1, hiport = -1;
+  int sk, fd, maxfd;
+  int *fsfd;
+  struct sockaddr_in sin;
+  socklen_t sasz;
+  struct hostent *h;
+  const char *p, *q;
+  struct timeval now, when, delta;
+  struct client *clients = 0, *c, **cc;
+  const struct token *t;
+  struct tokmatch tm;
+  fd_set fdin;
+  int i, e;
+  ssize_t n;
+  unsigned f = 0;
+#define f_bogus 0x01u
+#define f_notreally 0x02u
+
+  ego(argv[0]);
+  sub_init();
+
+  /* --- Partially initialize the socket address --- */
+
+  memset(&sin, 0, sizeof(sin));
+  sin.sin_family = AF_INET;
+  sin.sin_addr.s_addr = INADDR_ANY;
+  sin.sin_port = 0;
+
+  /* --- Parse the command line --- */
+
+  for (;;) {
+    static struct option opts[] = {
+      { "help",       0,           0, 'h' },
+      { "version",    0,           0, 'v' },
+      { "usage",      0,           0, 'u' },
+      { "address",    OPTF_ARGREQ, 0, 'a' },
+      { "not-really", 0,           0, 'n' },
+      { "port-range", OPTF_ARGREQ, 0, 'p' },
+      { 0,            0,           0, 0 }
+    };
+
+    if ((i = mdwopt(argc, argv, "hvua:np:", opts, 0, 0, 0)) < 0) break;
+    switch (i) {
+      case 'h': help(stdout); exit(0);
+      case 'v': version(stdout); exit(0);
+      case 'u': usage(stdout); exit(0);
+      case 'a':
+        if ((h = gethostbyname(optarg)) == 0) {
+          die(1, "failed to resolve address `%s': %s",
+              optarg, hstrerror(h_errno));
+        }
+        if (h->h_addrtype != AF_INET)
+          die(1, "unexpected address type resolving `%s'", optarg);
+        assert(h->h_length == sizeof(sin.sin_addr));
+        memcpy(&sin.sin_addr, h->h_addr, sizeof(sin.sin_addr));
+        break;
+      case 'n': f |= f_notreally; break;
+      case 'p':
+        if ((p = strchr(optarg, '-')) == 0)
+          loport = hiport = getuint(optarg, 0);
+        else {
+          loport = getuint(optarg, p);
+          hiport = getuint(p + 1, 0);
+        }
+        break;
+      default: f |= f_bogus; break;
+    }
+  }
+  if (f & f_bogus) { usage(stderr); exit(1); }
+  if (optind >= argc) { usage(stderr); exit(1); }
+
+  /* Ports must fit in 16 bits (@htons@ would silently truncate them), and
+   * a backwards range would otherwise be misreported as `all ports in
+   * use'. */
+  if (loport > 65535 || hiport > 65535 || loport > hiport)
+    die(1, "invalid port range `%d-%d'", loport, hiport);
+
+  /* --- Open the file systems --- *
+   *
+   * The descriptors are kept in @fsfd@ until each file system is actually
+   * being frozen, and only then published in @fs@, which is what the
+   * cleanup code thaws.  So an early exit, or a signal before or during
+   * freezing, never tries to thaw a file system which isn't frozen.
+   */
+
+  nfs = argc - optind;
+  fsname = &argv[optind];
+  fs = xmalloc(nfs*sizeof(*fs));
+  fsfd = xmalloc(nfs*sizeof(*fsfd));
+  for (i = 0; i < nfs; i++) fs[i] = -2;
+  for (i = 0; i < nfs; i++) {
+    if ((fsfd[i] = open(fsname[i], O_RDONLY)) < 0)
+      die(2, "open (%s): %s", fsname[i], strerror(errno));
+  }
+
+  if (f & f_notreally) {
+    for (i = 0; i < nfs; i++) {
+      close(fsfd[i]);
+      fsfd[i] = fs[i] = -1;
+    }
+  }
+
+  /* --- Generate random tokens --- */
+
+  inittoks();
+
+  /* --- Create the listening socket --- */
+
+  if ((sk = socket(PF_INET, SOCK_STREAM, 0)) < 0)
+    die(2, "socket: %s", strerror(errno));
+  i = 1;
+  if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &i, sizeof(i)))
+    die(2, "setsockopt (reuseaddr): %s", strerror(errno));
+  if (fdflags(sk, O_NONBLOCK, O_NONBLOCK, FD_CLOEXEC, FD_CLOEXEC))
+    die(2, "fdflags: %s", strerror(errno));
+  if (loport < 0 || loport == hiport) {
+    if (loport >= 0) sin.sin_port = htons(loport);
+    if (bind(sk, (struct sockaddr *)&sin, sizeof(sin)))
+      die(2, "bind: %s", strerror(errno));
+  } else {
+    for (i = loport; i <= hiport; i++) {
+      sin.sin_port = htons(i);
+      if (bind(sk, (struct sockaddr *)&sin, sizeof(sin)) >= 0) break;
+      else if (errno != EADDRINUSE)
+        die(2, "bind: %s", strerror(errno));
+    }
+    if (i > hiport) die(2, "bind: all ports in use");
+  }
+  if (listen(sk, 5)) die(2, "listen: %s", strerror(errno));
+
+  /* --- Tell the caller how to connect to us, and start the timer --- */
+
+  sasz = sizeof(sin);
+  if (getsockname(sk, (struct sockaddr *)&sin, &sasz))
+    die(2, "getsockname (listen): %s", strerror(errno));
+  printf("PORT %d\n", ntohs(sin.sin_port));
+  for (t = toktab; t->label; t++)
+    printf("TOKEN %s %s\n", t->label, t->tok);
+  printf("READY\n");
+  if (fflush(stdout) || ferror(stdout))
+    die(2, "write (stdout, rubric): %s", strerror(errno));
+  gettimeofday(&now, 0); TV_ADDL(&when, &now, TO_CONNECT, 0);
+
+  /* --- Collect incoming connections, and check for the cookie --- *
+   *
+   * This is the tricky part.  Any number of clients may connect; each is
+   * fed through its own token matcher, and is dropped on end-of-file, error
+   * or a wrong line.  The first to send FREEZE wins.  On leaving the loop,
+   * @c@ is the winner, @p@ points at the newline which ended its FREEZE
+   * line, and @q@ marks the end of the data read with it.
+   */
+
+  for (;;) {
+    FD_ZERO(&fdin);
+    FD_SET(sk, &fdin);
+    maxfd = sk;
+    for (c = clients; c; c = c->next) {
+      FD_SET(c->fd, &fdin);
+      if (c->fd > maxfd) maxfd = c->fd;
+    }
+    TV_SUB(&delta, &when, &now);
+    if (select(maxfd + 1, &fdin, 0, 0, &delta) < 0)
+      die(2, "select (accept): %s", strerror(errno));
+    gettimeofday(&now, 0);
+
+    if (TV_CMP(&now, >=, &when)) die(3, "timeout (accept)");
+
+    if (FD_ISSET(sk, &fdin)) {
+      sasz = sizeof(sin);
+      fd = accept(sk, (struct sockaddr *)&sin, &sasz);
+      if (fd >= FD_SETSIZE) {
+        /* Too big to go in an @fd_set@: @FD_SET@ would scribble past the
+         * end of @fdin@.  Turn this client away. */
+        D( moan("accept: descriptor %d too large for select", fd); )
+        close(fd);
+      } else if (fd >= 0) {
+        if (fdflags(fd, O_NONBLOCK, O_NONBLOCK, FD_CLOEXEC, FD_CLOEXEC) < 0)
+          die(2, "fdflags: %s", strerror(errno));
+        c = CREATE(struct client);
+        c->next = clients; c->fd = fd; tokmatch_init(&c->tm);
+        clients = c;
+      }
+#ifdef DEBUG
+      else if (errno != EAGAIN)
+        moan("accept: %s", strerror(errno));
+#endif
+    }
+
+    for (cc = &clients; *cc;) {
+      c = *cc;
+      if (!FD_ISSET(c->fd, &fdin)) goto next_client;
+      n = read(c->fd, buf, sizeof(buf));
+      if (!n) goto disconn;
+      else if (n < 0) {
+        if (errno == EAGAIN) goto next_client;
+        D( moan("read (client; auth): %s", strerror(errno)); )
+        goto disconn;
+      } else {
+        for (p = buf, q = p + n; p < q; p++) {
+          switch (tokmatch_update(&c->tm, *p)) {
+            case 0: break;
+            case TF_FREEZE: goto connected;
+            default:
+              D( moan("bad token from client"); )
+              goto disconn;
+          }
+        }
+      }
+
+    next_client:
+      cc = &c->next;
+      continue;
+
+    disconn:
+      close(c->fd);
+      *cc = c->next;
+      DESTROY(c);
+      continue;
+    }
+  }
+
+connected:
+  close(sk); sk = c->fd;
+  while (clients) {
+    if (clients->fd != sk) close(clients->fd);
+    c = clients->next;
+    DESTROY(clients);
+    clients = c;
+  }
+
+  /* --- Establish signal handlers --- *
+   *
+   * Hopefully this will prevent bad things happening if we have an accident.
+   */
+
+  for (i = 0; i < sizeof(sigcatch)/sizeof(sigcatch[0]); i++) {
+    if (signal(sigcatch[i], sigmumble) == SIG_ERR)
+      die(2, "signal (%d): %s", sigcatch[i], strerror(errno));
+  }
+  atexit(cleanup);
+
+  /* --- Prevent the OOM killer from clobbering us --- *
+   *
+   * NOTE(review): @oom_adj@ is deprecated in favour of @oom_score_adj@
+   * (where -1000 is the equivalent setting); consider switching.
+   */
+
+  if ((fd = open("/proc/self/oom_adj", O_WRONLY)) < 0 ||
+      write(fd, "-17\n", 4) < 4 ||
+      close(fd))
+    die(2, "set oom_adj: %s", strerror(errno));
+
+  /* --- Actually freeze the filesystem --- */
+
+  for (i = 0; i < nfs; i++) {
+    if (fsfd[i] == -1)
+      moan("not really freezing %s", fsname[i]);
+    else {
+      /* Publish before freezing, so that a signal arriving just after the
+       * freeze succeeds still finds this one to thaw. */
+      fs[i] = fsfd[i];
+      if (ioctl(fs[i], FIFREEZE, 0) < 0) {
+        e = errno;
+        fs[i] = -2;
+        close(fsfd[i]);
+        partial_cleanup(i);
+        die(2, "ioctl (freeze %s): %s", fsname[i], strerror(e));
+      }
+    }
+  }
+  xfree(fsfd);
+  if (writetok(T_FROZEN, sk)) {
+    e = errno;
+    cleanup();
+    die(2, "write (frozen): %s", strerror(e));
+  }
+
+  /* --- Now wait for the other end to detach --- *
+   *
+   * Freezing syncs the file systems and can take a long time, so take a
+   * fresh timestamp before starting the keepalive clock; otherwise the very
+   * first wakeup could be judged to be past the deadline.
+   */
+
+  tokmatch_init(&tm);
+  gettimeofday(&now, 0);
+  TV_ADDL(&when, &now, TO_KEEPALIVE, 0);
+
+  /* First, whatever followed the FREEZE line in the same read. */
+  for (p++; p < q; p++) {
+    switch (tokmatch_update(&tm, *p)) {
+      case 0: break;
+      case TF_KEEPALIVE: tokmatch_init(&tm); break;
+      case TF_THAW: goto done;
+      default: cleanup(); die(3, "unknown token (keepalive)");
+    }
+  }
+  for (;;) {
+    FD_ZERO(&fdin);
+    FD_SET(sk, &fdin);
+    TV_SUB(&delta, &when, &now);
+    if (select(sk + 1, &fdin, 0, 0, &delta) < 0) {
+      e = errno;
+      cleanup();
+      die(2, "select (keepalive): %s", strerror(e));
+    }
+
+    gettimeofday(&now, 0);
+    if (TV_CMP(&now, >, &when)) {
+      cleanup(); die(3, "timeout (keepalive)");
+    }
+    if (FD_ISSET(sk, &fdin)) {
+      n = read(sk, buf, sizeof(buf));
+      if (!n) { cleanup(); die(3, "end-of-file (keepalive)"); }
+      else if (n < 0) {
+        if (errno == EAGAIN) ;
+        else {
+          e = errno;
+          cleanup();
+          die(2, "read (client, keepalive): %s", strerror(e));
+        }
+      } else {
+        for (p = buf, q = p + n; p < q; p++) {
+          switch (tokmatch_update(&tm, *p)) {
+            case 0: break;
+            case TF_KEEPALIVE:
+              TV_ADDL(&when, &now, TO_KEEPALIVE, 0);
+              tokmatch_init(&tm);
+              break;
+            case TF_THAW:
+              goto done;
+            default:
+              cleanup();
+              die(3, "unknown token (keepalive)");
+          }
+        }
+      }
+    }
+  }
+
+done:
+  cleanup();
+  if (writetok(T_THAWED, sk))
+    die(2, "write (thaw): %s", strerror(errno));
+  close(sk);
+  return (0);
+}
+
+/*----- That's all, folks -------------------------------------------------*/