chiark / gitweb /
pkstream/pkstream.c: Allow multiple listening and peer addresses.
[tripe] / pkstream / pkstream.c
index c0deff73169d1f2540184fa87ab20d7a2619908b..3a5f3fde573c1a4d49b3be4e45dff8d49456cfd1 100644 (file)
@@ -45,6 +45,7 @@
 
 #include <mLib/alloc.h>
 #include <mLib/bits.h>
+#include <mLib/darray.h>
 #include <mLib/dstr.h>
 #include <mLib/fdflags.h>
 #include <mLib/mdwopt.h>
@@ -62,6 +63,9 @@ typedef union addr {
   struct sockaddr_in sin;
 } addr;
 
+DA_DECL(addr_v, addr);
+DA_DECL(str_v, const char *);
+
 typedef struct pk {
   struct pk *next;                     /* Next packet in the chain */
   octet *p, *o;                                /* Buffer start and current posn */
@@ -80,8 +84,8 @@ typedef struct pkstream {
 typedef struct connwait {
   unsigned f;                          /* Various flags */
 #define cwf_port 1u                    /*   Port is defined => listen */
-  sel_file a;                          /* Selector */
-  addr me, peer;                      /* Who I'm meant to be; who peer is */
+  sel_file *sfv;                       /* Selectors */
+  addr_v me, peer;                    /* Who I'm meant to be; who peer is */
 } connwait;
 
 /*----- Static variables --------------------------------------------------*/
@@ -261,37 +265,56 @@ static void doaccept(int fd_s, unsigned mode, void *p)
   int fd;
   addr a;
   socklen_t sz = sizeof(a);
+  size_t i, n;
 
   if ((fd = accept(fd_s, &a.sa, &sz)) < 0) {
     if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) return;
     moan("couldn't accept incoming connection: %s", strerror(errno));
     return;
   }
-  if (cw.peer.sin.sin_addr.s_addr != INADDR_ANY && !addreq(&a, &cw.peer)) {
-    moan("rejecting connection from %s", addrstr(&a));
-    close(fd); return;
-  }
+  n = DA_LEN(&cw.peer);
+  if (!n) goto match;
+  for (i = 0; i < n; i++) if (addreq(&a, &DA(&cw.peer)[i])) goto match;
+  moan("rejecting connection from %s", addrstr(&a));
+  close(fd); return;
+match:
   if (nonblockify(fd) || cloexec(fd)) {
     moan("couldn't accept incoming connection: %s", strerror(errno));
     close(fd); return;
   }
   dofwd(fd, fd);
-  close(fd_s);
-  sel_rmfile(&cw.a);
+  n = DA_LEN(&cw.me);
+  for (i = 0; i < n; i++) { close(cw.sfv[i].fd); sel_rmfile(&cw.sfv[i]); }
 }
 
-static void dolisten(void)
+static void dolisten1(const addr *a, sel_file *sf)
 {
   int fd;
   int opt = 1;
 
-  if ((fd = socket(cw.me.sa.sa_family, SOCK_STREAM, 0)) < 0 ||
+  if ((fd = socket(a->sa.sa_family, SOCK_STREAM, IPPROTO_TCP)) < 0 ||
       setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) ||
-      bind(fd, &cw.me.sa, addrsz(&cw.me)) ||
+      bind(fd, &a->sa, addrsz(a)) ||
       listen(fd, 1) || nonblockify(fd) || cloexec(fd))
     die(1, "couldn't set up listening socket: %s", strerror(errno));
-  sel_initfile(&sel, &cw.a, fd, SEL_READ, doaccept, 0);
-  sel_addfile(&cw.a);
+  sel_initfile(&sel, sf, fd, SEL_READ, doaccept, 0);
+  sel_addfile(sf);
+}
+
+static void dolisten(void)
+{
+  size_t i, n;
+
+  n = DA_LEN(&cw.me);
+  for (i = 0; i < n; i++)
+    dolisten1(&DA(&cw.me)[i], &cw.sfv[i]);
+}
+
+static void pushaddr(addr_v *av, const addr *a)
+{
+  DA_ENSURE(av, 1);
+  DA(av)[DA_LEN(av)] = *a;
+  DA_EXTEND(av, 1);
 }
 
 #define paf_parse 1u
@@ -362,12 +385,14 @@ stdout; though it can use TCP sockets instead.\n\
 int main(int argc, char *argv[])
 {
   unsigned f = 0;
-  const char *bindhost = 0, *bindsvc = 0, *peerhost = 0;
+  str_v bindhosts = DA_INIT, peerhosts = DA_INIT;
+  const char *bindsvc = 0;
   addr bindaddr;
   const char *connhost = 0;
   addr tmpaddr;
   int fd = -1;
   int len = 65536;
+  size_t i, n;
 
 #define f_bogus 1u
 
@@ -396,39 +421,59 @@ int main(int argc, char *argv[])
       case 'v': version(stdout); exit(0);
       case 'u': usage(stdout); exit(0);
       case 'l': bindsvc = optarg; break;
-      case 'p': peerhost = optarg; break;
-      case 'b': bindhost = optarg; break;
+      case 'p': DA_PUSH(&peerhosts, optarg); break;
+      case 'b': DA_PUSH(&bindhosts, optarg); break;
       case 'c': connhost = optarg; break;
       default: f |= f_bogus; break;
     }
   }
   if (optind + 2 != argc || (f&f_bogus)) { usage(stderr); exit(1); }
 
-  if (bindhost && !bindsvc && !connhost)
+  if (DA_LEN(&bindhosts) && !bindsvc && !connhost)
     die(1, "bind addr only makes sense when listening or connecting");
-  if (peerhost && !bindsvc)
+  if (DA_LEN(&peerhosts) && !bindsvc)
     die(1, "peer addr only makes sense when listening");
   if (bindsvc && connhost)
     die(1, "can't listen and connect");
 
-  if (bindhost || bindsvc) {
-    initaddr(&bindaddr);
-    if (!bindsvc) parseaddr(bindhost, 0, 0, &bindaddr);
-    else {
-      initaddr(&cw.me);
-      parseaddr(bindhost, bindsvc, 0, &cw.me);
+  DA_CREATE(&cw.me); DA_CREATE(&cw.peer);
+
+  n = DA_LEN(&bindhosts);
+  if (n || bindsvc) {
+    if (!n) {
+      initaddr(&tmpaddr);
+      parseaddr(0, bindsvc, 0, &tmpaddr);
+      pushaddr(&cw.me, &tmpaddr);
+    } else if (!bindsvc) {
+      if (n != 1) die(1, "can only bind to one address as client");
+      initaddr(&bindaddr);
+      parseaddr(DA(&bindhosts)[0], 0, 0, &bindaddr);
+    } else for (i = 0; i < n; i++) {
+      initaddr(&tmpaddr);
+      parseaddr(DA(&bindhosts)[i], bindsvc, 0, &tmpaddr);
+      pushaddr(&cw.me, &tmpaddr);
+    }
+    if (bindsvc) {
       cw.f |= cwf_port;
+      n = DA_LEN(&cw.me);
+      cw.sfv = xmalloc(n*sizeof(*cw.sfv));
     }
   }
 
-  initaddr(&cw.peer);
-  if (peerhost) parseaddr(peerhost, 0, 0, &cw.peer);
+  n = DA_LEN(&peerhosts);
+  if (n) {
+    for (i = 0; i < n; i++) {
+      initaddr(&tmpaddr);
+      parseaddr(DA(&peerhosts)[0], 0, 0, &tmpaddr);
+      pushaddr(&cw.peer, &tmpaddr);
+    }
+  }
 
   if (connhost) {
     initaddr(&tmpaddr);
     parseaddr(connhost, 0, paf_parse, &tmpaddr);
     if ((fd = socket(tmpaddr.sa.sa_family, SOCK_STREAM, IPPROTO_TCP)) < 0 ||
-       (bindhost &&
+       (DA_LEN(&bindhosts) &&
         bind(fd, &bindaddr.sa, addrsz(&bindaddr))) ||
        connect(fd, &tmpaddr.sa, addrsz(&tmpaddr)))
       die(1, "couldn't connect to TCP server: %s", strerror(errno));