chiark / gitweb /
make-secnet-sites: new "include" keyword
[secnet.git] / slip.c
diff --git a/slip.c b/slip.c
index afc0528..a296e42 100644 (file)
--- a/slip.c
+++ b/slip.c
 #include "secnet.h"
 #include "util.h"
 #include "netlink.h"
+#include "process.h"
 #include <stdio.h>
 #include <string.h>
 #include <unistd.h>
+#include <errno.h>
+#include <fcntl.h>
 
 #define SLIP_END    192
 #define SLIP_ESC    219
 #define SLIP_ESCEND 220
 #define SLIP_ESCESC 221
 
-/* Connection to the kernel through userv-ipif */
-
-struct userv {
+struct slip {
     struct netlink nl;
-    int txfd; /* We transmit to userv */
-    int rxfd; /* We receive from userv */
-    string_t userv_path;
-    string_t service_user;
-    string_t service_name;
-    uint32_t txbuflen;
     struct buffer_if *buff; /* We unstuff received packets into here
-                              and send them to the site code. */
+                              and send them to the netlink code. */
     bool_t pending_esc;
     netlink_deliver_fn *netlink_to_tunnel;
-    uint32_t local_address; /* host interface address */
+    uint32_t local_address;
 };
 
-static int userv_beforepoll(void *sst, struct pollfd *fds, int *nfds_io,
-                           int *timeout_io, const struct timeval *tv_now,
-                           uint64_t *now)
-{
-    struct userv *st=sst;
-    *nfds_io=2;
-    fds[0].fd=st->txfd;
-    fds[0].events=POLLERR; /* Might want to pick up POLLOUT sometime */
-    fds[1].fd=st->rxfd;
-    fds[1].events=POLLIN|POLLERR|POLLHUP;
-    return 0;
-}
+/* Generic SLIP mangling code */
 
-static void userv_afterpoll(void *sst, struct pollfd *fds, int nfds,
-                           const struct timeval *tv_now, uint64_t *now)
+static void slip_stuff(struct slip *st, struct buffer_if *buf, int fd)
 {
-    struct userv *st=sst;
-    uint8_t rxbuf[DEFAULT_BUFSIZE];
-    int l,i;
-
-    if (fds[1].revents&POLLERR) {
-       Message(M_ERROR,"%s: userv_afterpoll: hup!\n",st->nl.name);
-    }
-    if (fds[1].revents&POLLIN) {
-       l=read(st->rxfd,rxbuf,DEFAULT_BUFSIZE);
-       if (l<0) {
-           fatal_perror("%s: userv_afterpoll: read(rxfd)",st->nl.name);
-       }
-       if (l==0) {
-           fatal("%s: userv_afterpoll: read(rxfd)=0; userv gone away?\n",
-                 st->nl.name);
-       }
-       /* XXX really crude unstuff code */
-       /* XXX check for buffer overflow */
-       BUF_ASSERT_USED(st->buff);
-       for (i=0; i<l; i++) {
-           if (st->pending_esc) {
-               st->pending_esc=False;
-               switch(rxbuf[i]) {
-               case SLIP_ESCEND:
-                   *(uint8_t *)buf_append(st->buff,1)=SLIP_END;
-                   break;
-               case SLIP_ESCESC:
-                   *(uint8_t *)buf_append(st->buff,1)=SLIP_ESC;
-                   break;
-               default:
-                   fatal("userv_afterpoll: bad SLIP escape character\n");
-               }
-           } else {
-               switch (rxbuf[i]) {
-               case SLIP_END:
-                   if (st->buff->size>0) {
-                       st->netlink_to_tunnel(&st->nl,NULL,
-                                             st->buff);
-                       BUF_ALLOC(st->buff,"userv_afterpoll");
-                   }
-                   buffer_init(st->buff,st->nl.max_start_pad);
-                   break;
-               case SLIP_ESC:
-                   st->pending_esc=True;
-                   break;
-               default:
-                   *(uint8_t *)buf_append(st->buff,1)=rxbuf[i];
-                   break;
-               }
-           }
-       }
-    }
-}
-
-/* Send buf to the kernel. Free buf before returning. */
-static void userv_deliver_to_kernel(void *sst, void *cid,
-                                   struct buffer_if *buf)
-{
-    struct userv *st=sst;
     uint8_t txbuf[DEFAULT_BUFSIZE];
     uint8_t *i;
-    uint32_t j;
+    int32_t j=0;
 
     BUF_ASSERT_USED(buf);
 
-    /* Spit the packet at userv-ipif: SLIP start marker, then
-       bytestuff the packet, then SLIP end marker */
-    /* XXX crunchy bytestuff code */
-    j=0;
+    /* There's probably a much more efficient way of implementing this */
     txbuf[j++]=SLIP_END;
     for (i=buf->start; i<(buf->start+buf->size); i++) {
        switch (*i) {
@@ -132,102 +53,319 @@ static void userv_deliver_to_kernel(void *sst, void *cid,
            txbuf[j++]=*i;
            break;
        }
+       if ((j+2)>DEFAULT_BUFSIZE) {
+           if (write(fd,txbuf,j)<0) {
+               fatal_perror("slip_stuff: write()");
+           }
+           j=0;
+       }
     }
     txbuf[j++]=SLIP_END;
-    if (write(st->txfd,txbuf,j)<0) {
-       fatal_perror("userv_deliver_to_kernel: write()");
+    if (write(fd,txbuf,j)<0) {
+       fatal_perror("slip_stuff: write()");
     }
     BUF_FREE(buf);
 }
 
-static void userv_phase_hook(void *sst, uint32_t newphase)
+static void slip_unstuff(struct slip *st, uint8_t *buf, uint32_t l)
+{
+    uint32_t i;
+
+    BUF_ASSERT_USED(st->buff);
+    for (i=0; i<l; i++) {
+       if (st->pending_esc) {
+           st->pending_esc=False;
+           switch(buf[i]) {
+           case SLIP_ESCEND:
+               *(uint8_t *)buf_append(st->buff,1)=SLIP_END;
+               break;
+           case SLIP_ESCESC:
+               *(uint8_t *)buf_append(st->buff,1)=SLIP_ESC;
+               break;
+           default:
+               fatal("userv_afterpoll: bad SLIP escape character");
+           }
+       } else {
+           switch (buf[i]) {
+           case SLIP_END:
+               if (st->buff->size>0) {
+                   st->netlink_to_tunnel(&st->nl,st->buff);
+                   BUF_ALLOC(st->buff,"userv_afterpoll");
+               }
+               buffer_init(st->buff,st->nl.max_start_pad);
+               break;
+           case SLIP_ESC:
+               st->pending_esc=True;
+               break;
+           default:
+               *(uint8_t *)buf_append(st->buff,1)=buf[i];
+               break;
+           }
+       }
+    }
+}
+
+static void slip_init(struct slip *st, struct cloc loc, dict_t *dict,
+                     cstring_t name, netlink_deliver_fn *to_host)
+{
+    st->netlink_to_tunnel=
+       netlink_init(&st->nl,st,loc,dict,
+                    "netlink-userv-ipif",NULL,to_host);
+    st->buff=find_cl_if(dict,"buffer",CL_BUFFER,True,"name",loc);
+    st->local_address=string_item_to_ipaddr(
+       dict_find_item(dict,"local-address", True, name, loc),"netlink");
+    BUF_ALLOC(st->buff,"slip_init");
+    st->pending_esc=False;
+}
+
+/* Connection to the kernel through userv-ipif */
+
+struct userv {
+    struct slip slip;
+    int txfd; /* We transmit to userv */
+    int rxfd; /* We receive from userv */
+    cstring_t userv_path;
+    cstring_t service_user;
+    cstring_t service_name;
+    pid_t pid;
+    bool_t expecting_userv_exit;
+};
+
+static int userv_beforepoll(void *sst, struct pollfd *fds, int *nfds_io,
+                           int *timeout_io)
 {
     struct userv *st=sst;
-    pid_t child;
+
+    if (st->rxfd!=-1) {
+       *nfds_io=2;
+       fds[0].fd=st->txfd;
+       fds[0].events=0; /* Might want to pick up POLLOUT sometime */
+       fds[1].fd=st->rxfd;
+       fds[1].events=POLLIN;
+    } else {
+       *nfds_io=0;
+    }
+    return 0;
+}
+
+static void userv_afterpoll(void *sst, struct pollfd *fds, int nfds)
+{
+    struct userv *st=sst;
+    uint8_t rxbuf[DEFAULT_BUFSIZE];
+    int l;
+
+    if (nfds==0) return;
+
+    if (fds[1].revents&POLLERR) {
+       Message(M_ERR,"%s: userv_afterpoll: POLLERR!\n",st->slip.nl.name);
+    }
+    if (fds[1].revents&POLLIN) {
+       l=read(st->rxfd,rxbuf,DEFAULT_BUFSIZE);
+       if (l<0) {
+           if (errno!=EINTR)
+               fatal_perror("%s: userv_afterpoll: read(rxfd)",
+                            st->slip.nl.name);
+       } else if (l==0) {
+           fatal("%s: userv_afterpoll: read(rxfd)=0; userv gone away?",
+                 st->slip.nl.name);
+       } else slip_unstuff(&st->slip,rxbuf,l);
+    }
+}
+
+/* Send buf to the kernel. Free buf before returning. */
+static void userv_deliver_to_kernel(void *sst, struct buffer_if *buf)
+{
+    struct userv *st=sst;
+
+    slip_stuff(&st->slip,buf,st->txfd);
+}
+
+static void userv_userv_callback(void *sst, pid_t pid, int status)
+{
+    struct userv *st=sst;
+
+    if (pid!=st->pid) {
+       Message(M_WARNING,"userv_callback called unexpectedly with pid %d "
+               "(expected %d)\n",pid,st->pid);
+       return;
+    }
+    if (!st->expecting_userv_exit) {
+       if (WIFEXITED(status)) {
+           fatal("%s: userv exited unexpectedly with status %d",
+                 st->slip.nl.name,WEXITSTATUS(status));
+       } else if (WIFSIGNALED(status)) {
+           fatal("%s: userv exited unexpectedly: uncaught signal %d",
+                 st->slip.nl.name,WTERMSIG(status));
+       } else {
+           fatal("%s: userv stopped unexpectedly");
+       }
+    }
+    Message(M_WARNING,"%s: userv subprocess died with status %d\n",
+           st->slip.nl.name,WEXITSTATUS(status));
+    st->pid=0;
+}
+
+struct userv_entry_rec {
+    cstring_t path;
+    const char **argv;
+    int in;
+    int out;
+    /* XXX perhaps we should collect and log stderr? */
+};
+
+static void userv_entry(void *sst)
+{
+    struct userv_entry_rec *st=sst;
+
+    dup2(st->in,0);
+    dup2(st->out,1);
+
+    /* XXX close all other fds */
+    setsid();
+    /* XXX We really should strdup() all of argv[] but because we'll just
+       exit anyway if execvp() fails it doesn't seem worth bothering. */
+    execvp(st->path,(char *const*)st->argv);
+    perror("userv-entry: execvp()");
+    exit(1);
+}
+
+static void userv_invoke_userv(struct userv *st)
+{
+    struct userv_entry_rec *er;
     int c_stdin[2];
     int c_stdout[2];
     string_t addrs;
     string_t nets;
     string_t s;
-    struct netlink_route *r;
-    int i;
+    struct netlink_client *r;
+    struct ipset *allnets;
+    struct subnet_list *snets;
+    int i, nread;
+    uint8_t confirm;
+
+    if (st->pid) {
+       fatal("userv_invoke_userv: already running");
+    }
 
     /* This is where we actually invoke userv - all the networks we'll
        be using should already have been registered. */
 
-    addrs=safe_malloc(512,"userv_phase_hook:addrs");
-    snprintf(addrs,512,"%s,%s,%d,slip",ipaddr_to_string(st->local_address),
-            ipaddr_to_string(st->nl.secnet_address),st->nl.mtu);
+    addrs=safe_malloc(512,"userv_invoke_userv:addrs");
+    snprintf(addrs,512,"%s,%s,%d,slip",
+            ipaddr_to_string(st->slip.local_address),
+            ipaddr_to_string(st->slip.nl.secnet_address),st->slip.nl.mtu);
 
-    nets=safe_malloc(1024,"userv_phase_hook:nets");
-    *nets=0;
-    r=st->nl.routes;
-    for (i=0; i<st->nl.n_routes; i++) {
-       if (r[i].up) {
-           r[i].kup=True;
-           s=subnet_to_string(&r[i].net);
-           strcat(nets,s);
-           strcat(nets,",");
-           free(s);
+    allnets=ipset_new();
+    for (r=st->slip.nl.clients; r; r=r->next) {
+       if (r->up) {
+           struct ipset *nan;
+           r->kup=True;
+           nan=ipset_union(allnets,r->networks);
+           ipset_free(allnets);
+           allnets=nan;
        }
     }
+    snets=ipset_to_subnet_list(allnets);
+    ipset_free(allnets);
+    nets=safe_malloc(20*snets->entries,"userv_invoke_userv:nets");
+    *nets=0;
+    for (i=0; i<snets->entries; i++) {
+       s=subnet_to_string(snets->list[i]);
+       strcat(nets,s);
+       strcat(nets,",");
+       free(s);
+    }
     nets[strlen(nets)-1]=0;
+    subnet_list_free(snets);
 
-    Message(M_INFO,"%s: about to invoke: %s %s %s %s %s\n",st->nl.name,
+    Message(M_INFO,"%s: about to invoke: %s %s %s %s %s\n",st->slip.nl.name,
            st->userv_path,st->service_user,st->service_name,addrs,nets);
 
-    /* Allocate buffer, plus space for padding. Make sure we end up
-       with the start of the packet well-aligned. */
-    /* ALIGN(st->max_start_pad,16); */
-    /* ALIGN(st->max_end_pad,16); */
-
-    st->pending_esc=False;
+    st->slip.pending_esc=False;
 
     /* Invoke userv */
     if (pipe(c_stdin)!=0) {
-       fatal_perror("userv_phase_hook: pipe(c_stdin)");
+       fatal_perror("userv_invoke_userv: pipe(c_stdin)");
     }
     if (pipe(c_stdout)!=0) {
-       fatal_perror("userv_phase_hook: pipe(c_stdout)");
+       fatal_perror("userv_invoke_userv: pipe(c_stdout)");
     }
     st->txfd=c_stdin[1];
     st->rxfd=c_stdout[0];
 
-    child=fork();
-    if (child==-1) {
-       fatal_perror("userv_phase_hook: fork()");
+    er=safe_malloc(sizeof(*r),"userv_invoke_userv: er");
+
+    er->in=c_stdin[0];
+    er->out=c_stdout[1];
+    /* The arguments are:
+       userv
+       service-user
+       service-name
+       local-addr,secnet-addr,mtu,protocol
+       route1,route2,... */
+    er->argv=safe_malloc(sizeof(*er->argv)*6,"userv_invoke_userv:argv");
+    er->argv[0]=st->userv_path;
+    er->argv[1]=st->service_user;
+    er->argv[2]=st->service_name;
+    er->argv[3]=addrs;
+    er->argv[4]=nets;
+    er->argv[5]=NULL;
+    er->path=st->userv_path;
+
+    st->pid=makesubproc(userv_entry, userv_userv_callback,
+                       er, st, st->slip.nl.name);
+    close(er->in);
+    close(er->out);
+    free(er->argv);
+    free(er);
+    free(addrs);
+    free(nets);
+    Message(M_INFO,"%s: userv-ipif pid is %d\n",st->slip.nl.name,st->pid);
+    /* Read a single character from the pipe to confirm userv-ipif is
+       running. If we get a SIGCHLD at this point then we'll get EINTR. */
+    if ((nread=read(st->rxfd,&confirm,1))!=1) {
+       if (errno==EINTR) {
+           Message(M_WARNING,"%s: read of confirmation byte was "
+                   "interrupted\n",st->slip.nl.name);
+       } else {
+           if (nread<0) {
+               fatal_perror("%s: error reading confirmation byte",
+                            st->slip.nl.name);
+           } else {
+               fatal("%s: unexpected EOF instead of confirmation byte"
+                     " - userv ipif failed?", st->slip.nl.name);
+           }
+       }
+    } else {
+       if (confirm!=SLIP_END) {
+           fatal("%s: bad confirmation byte %d from userv-ipif",
+                 st->slip.nl.name,confirm);
+       }
+    }
+}
+
+static void userv_kill_userv(struct userv *st)
+{
+    if (st->pid) {
+       kill(-st->pid,SIGTERM);
+       st->expecting_userv_exit=True;
     }
-    if (child==0) {
-       char **argv;
-
-       /* We are the child. Modify our stdin and stdout, then exec userv */
-       dup2(c_stdin[0],0);
-       dup2(c_stdout[1],1);
-       close(c_stdin[1]);
-       close(c_stdout[0]);
-
-       /* The arguments are:
-          userv
-          service-user
-          service-name
-          local-addr,secnet-addr,mtu,protocol
-          route1,route2,... */
-       argv=malloc(sizeof(*argv)*6);
-       argv[0]=st->userv_path;
-       argv[1]=st->service_user;
-       argv[2]=st->service_name;
-       argv[3]=addrs;
-       argv[4]=nets;
-       argv[5]=NULL;
-       execvp(st->userv_path,argv);
-       perror("netlink-userv-ipif: execvp");
-
-       exit(1);
+}
+
+static void userv_phase_hook(void *sst, uint32_t newphase)
+{
+    struct userv *st=sst;
+    /* We must wait until signal processing has started before forking
+       userv */
+    if (newphase==PHASE_RUN) {
+       userv_invoke_userv(st);
+       /* Register for poll() */
+       register_for_poll(st, userv_beforepoll, userv_afterpoll, 2,
+                         st->slip.nl.name);
+    }
+    if (newphase==PHASE_SHUTDOWN) {
+       userv_kill_userv(st);
     }
-    /* We are the parent... */
-          
-    /* Register for poll() */
-    register_for_poll(st, userv_beforepoll, userv_afterpoll, 2, st->nl.name);
 }
 
 static list_t *userv_apply(closure_t *self, struct cloc loc, dict_t *context,
@@ -246,9 +384,8 @@ static list_t *userv_apply(closure_t *self, struct cloc loc, dict_t *context,
     
     dict=item->data.dict;
 
-    st->netlink_to_tunnel=
-       netlink_init(&st->nl,st,loc,dict,
-                    "netlink-userv-ipif",NULL,userv_deliver_to_kernel);
+    slip_init(&st->slip,loc,dict,"netlink-userv-ipif",
+             userv_deliver_to_kernel);
 
     st->userv_path=dict_read_string(dict,"userv-path",False,"userv-netlink",
                                    loc);
@@ -259,24 +396,16 @@ static list_t *userv_apply(closure_t *self, struct cloc loc, dict_t *context,
     if (!st->userv_path) st->userv_path="userv";
     if (!st->service_user) st->service_user="root";
     if (!st->service_name) st->service_name="ipif";
-    st->buff=find_cl_if(dict,"buffer",CL_BUFFER,True,"userv-netlink",loc);
-    st->local_address=string_to_ipaddr(
-       dict_find_item(dict,"local-address", True, "netlink", loc),"netlink");
-    BUF_ALLOC(st->buff,"netlink:userv_apply");
-
     st->rxfd=-1; st->txfd=-1;
-    add_hook(PHASE_DROPPRIV,userv_phase_hook,st);
+    st->pid=0;
+    st->expecting_userv_exit=False;
+    add_hook(PHASE_RUN,userv_phase_hook,st);
+    add_hook(PHASE_SHUTDOWN,userv_phase_hook,st);
 
-    return new_closure(&st->nl.cl);
+    return new_closure(&st->slip.nl.cl);
 }
 
-init_module slip_module;
 void slip_module(dict_t *dict)
 {
     add_closure(dict,"userv-ipif",userv_apply);
-#if 0
-    /* TODO */
-    add_closure(dict,"pty-slip",ptyslip_apply);
-    add_closure(dict,"slipd",slipd_apply);
-#endif /* 0 */
 }