chiark / gitweb /
pkstream/pkstream.c: Rearrange socket setup, particularly `parseaddr'.
authorMark Wooding <mdw@distorted.org.uk>
Wed, 27 Sep 2017 19:49:00 +0000 (20:49 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 16 Jun 2018 18:14:10 +0000 (19:14 +0100)
  * Have `parseaddr' fill in a socket address structure directly.

  * Change the interface to pass in either separate host and
    service (does this remind you of anything?) names, or a single
    combined string to be parsed apart, as indicated by a new flag
    `paf_parse'.

  * Have `main' keep track of the various host and service name strings
    and then sort everything out at the end, rather than exercising the
    resolver during option parsing.  Take advantage of this by
    diagnosing incompatible option combinations.

  * To make this work, upgrade `cw.peer' to be full socket address.

  * Factor out socket-address initialization, and initialize the
    structures on demand rather than in advance.

pkstream/pkstream.c

index 88ff381a8b8187f948873598b973eb5a74b700e0..f1ddf7791dfa5e486ff90f13797a7877b3b095c0 100644 (file)
@@ -74,8 +74,7 @@ typedef struct pkstream {
 
 typedef struct connwait {
   sel_file a;                          /* Selector */
-  struct sockaddr_in me;               /* Who I'm meant to be */
-  struct in_addr peer;                 /* Who my peer is */
+  struct sockaddr_in me, peer;        /* Who I'm meant to be; who peer is */
 } connwait;
 
 /*----- Static variables --------------------------------------------------*/
@@ -93,6 +92,13 @@ static int nonblockify(int fd)
 static int cloexec(int fd)
   { return (fdflags(fd, 0, 0, FD_CLOEXEC, FD_CLOEXEC)); }
 
+static void initaddr(struct sockaddr_in *sin)
+{
+  sin->sin_family = AF_INET;
+  sin->sin_addr.s_addr = INADDR_ANY;
+  sin->sin_port = 0;
+}
+
 static void dolisten(void);
 
 static void doclose(pkstream *p)
@@ -225,8 +231,8 @@ static void doaccept(int fd_s, unsigned mode, void *p)
     moan("couldn't accept incoming connection: %s", strerror(errno));
     return;
   }
-  if (cw.peer.s_addr != INADDR_ANY &&
-      cw.peer.s_addr != sin.sin_addr.s_addr) {
+  if (cw.peer.sin_addr.s_addr != INADDR_ANY &&
+      cw.peer.sin_addr.s_addr != sin.sin_addr.s_addr) {
     moan("rejecting connection from %s", inet_ntoa(sin.sin_addr));
     close(fd); return;
   }
@@ -253,33 +259,38 @@ static void dolisten(void)
   sel_addfile(&cw.a);
 }
 
-static void parseaddr(const char *pp, struct in_addr *a, unsigned short *pt)
+#define paf_parse 1u
+static void parseaddr(const char *host, const char *svc, unsigned f,
+                     struct sockaddr_in *sin)
 {
-  char *p = xstrdup(pp);
-  char *q = 0;
+  char *alloc = 0, *sep;
   struct hostent *h;
   struct servent *s;
   char *qq;
   unsigned long n;
 
-  if (!pt);
-  else if (a) q = p;
-  else {
-    strtok(p, ":");
-    q = strtok(0, "");
-    if (!q) die(1, "missing port number in address `%s'", p);
+  if (f&paf_parse) {
+    alloc = xstrdup(host);
+    if ((sep = strchr(alloc, ':')) == 0)
+      die(1, "missing port number in address `%s'", host);
+    host = alloc; *sep = 0; svc = sep + 1;
   }
 
-  if (a) {
-    if ((h = gethostbyname(p)) == 0) die(1, "unknown host `%s'", p);
-    memcpy(a, h->h_addr, sizeof(*a));
+  if (host) {
+    if ((h = gethostbyname(host)) == 0) die(1, "unknown host `%s'", host);
+    memcpy(&sin->sin_addr, h->h_addr, sizeof(sin->sin_addr));
   }
 
-  if (pt) {
-    if ((n = strtoul(q, &qq, 0)) > 0 && !*qq && n <= 0xffff) *pt = htons(n);
-    else if ((s = getservbyname(q, "tcp")) != 0) *pt = s->s_port;
-    else die(1, "bad port number `%s'", q);
+  if (svc) {
+    if ((n = strtoul(svc, &qq, 0)) > 0 && !*qq && n <= 0xffff)
+      sin->sin_port = htons(n);
+    else if ((s = getservbyname(svc, "tcp")) != 0)
+      sin->sin_port = s->s_port;
+    else
+      die(1, "bad service name/number `%s'", svc);
   }
+
+  xfree(alloc);
 }
 
 static void usage(FILE *fp)
@@ -317,24 +328,16 @@ stdout; though it can use TCP sockets instead.\n\
 int main(int argc, char *argv[])
 {
   unsigned f = 0;
-  unsigned short pt;
-  struct sockaddr_in connaddr, bindaddr;
-  struct sockaddr_in udp_me, udp_peer;
-  int fd;
+  const char *bindhost = 0, *bindsvc = 0, *peerhost = 0;
+  struct sockaddr_in bindaddr;
+  const char *connhost = 0;
+  struct sockaddr_in tmpaddr;
+  int fd = -1;
   int len = 65536;
 
 #define f_bogus 1u
 
   ego(argv[0]);
-  bindaddr.sin_family = AF_INET;
-  bindaddr.sin_addr.s_addr = INADDR_ANY;
-  bindaddr.sin_port = 0;
-  connaddr.sin_family = AF_INET;
-  connaddr.sin_addr.s_addr = INADDR_ANY;
-  cw.me.sin_family = AF_INET;
-  cw.me.sin_addr.s_addr = INADDR_ANY;
-  cw.me.sin_port = 0;
-  cw.peer.s_addr = INADDR_ANY;
   sel_init(&sel);
   for (;;) {
     static struct option opt[] = {
@@ -356,47 +359,60 @@ int main(int argc, char *argv[])
       case 'h': help(stdout); exit(0);
       case 'v': version(stdout); exit(0);
       case 'u': usage(stdout); exit(0);
-      case 'l': parseaddr(optarg, 0, &pt); cw.me.sin_port = pt; break;
-      case 'p': parseaddr(optarg, &cw.peer, 0); break;
-      case 'b':
-       parseaddr(optarg, &bindaddr.sin_addr, 0);
-       cw.me.sin_addr = bindaddr.sin_addr;
-       break;
-      case 'c':
-       parseaddr(optarg, &connaddr.sin_addr, 0);
-       connaddr.sin_port = pt;
-       break;
+      case 'l': bindsvc = optarg; break;
+      case 'p': peerhost = optarg; break;
+      case 'b': bindhost = optarg; break;
+      case 'c': connhost = optarg; break;
       default: f |= f_bogus; break;
     }
   }
   if (optind + 2 != argc || (f&f_bogus)) { usage(stderr); exit(1); }
 
-  udp_me.sin_family = udp_peer.sin_family = AF_INET;
-  parseaddr(argv[optind], &udp_me.sin_addr, &pt);
-  udp_me.sin_port = pt;
-  parseaddr(argv[optind + 1], &udp_peer.sin_addr, &pt);
-  udp_peer.sin_port = pt;
+  if (bindhost && !bindsvc && !connhost)
+    die(1, "bind addr only makes sense when listening or connecting");
+  if (peerhost && !bindsvc)
+    die(1, "peer addr only makes sense when listening");
+  if (bindsvc && connhost)
+    die(1, "can't listen and connect");
+
+  initaddr(&cw.me);
+  if (bindhost || bindsvc) {
+    initaddr(&bindaddr);
+    if (!bindsvc) parseaddr(bindhost, 0, 0, &bindaddr);
+    else parseaddr(bindhost, bindsvc, 0, &cw.me);
+  }
+
+  initaddr(&cw.peer);
+  if (peerhost) parseaddr(peerhost, 0, 0, &cw.peer);
 
-  if ((fd_udp = socket(PF_INET, SOCK_DGRAM, 0)) < 0 ||
-      bind(fd_udp, (struct sockaddr *)&udp_me, sizeof(udp_me)) ||
-      connect(fd_udp, (struct sockaddr *)&udp_peer, sizeof(udp_peer)) ||
+  if (connhost) {
+    initaddr(&tmpaddr);
+    parseaddr(connhost, 0, paf_parse, &tmpaddr);
+    if ((fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0 ||
+       (bindhost &&
+        bind(fd, (struct sockaddr *)&bindaddr, sizeof(bindaddr))) ||
+       connect(fd, (struct sockaddr *)&tmpaddr, sizeof(tmpaddr)))
+      die(1, "couldn't connect to TCP server: %s", strerror(errno));
+    if (nonblockify(fd) || cloexec(fd))
+      die(1, "couldn't connect to TCP server: %s", strerror(errno));
+  }
+
+  initaddr(&tmpaddr);
+  parseaddr(argv[optind], 0, paf_parse, &tmpaddr);
+  if ((fd_udp = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) < 0 ||
+      nonblockify(fd_udp) || cloexec(fd_udp) ||
       setsockopt(fd_udp, SOL_SOCKET, SO_RCVBUF, &len, sizeof(len)) ||
       setsockopt(fd_udp, SOL_SOCKET, SO_SNDBUF, &len, sizeof(len)) ||
-      nonblockify(fd_udp) || cloexec(fd_udp))
+      bind(fd_udp, (struct sockaddr *)&tmpaddr, sizeof(tmpaddr)))
+    die(1, "couldn't set up UDP socket: %s", strerror(errno));
+  initaddr(&tmpaddr);
+  parseaddr(argv[optind + 1], 0, paf_parse, &tmpaddr);
+  if (connect(fd_udp, (struct sockaddr *)&tmpaddr, sizeof(tmpaddr)))
     die(1, "couldn't set up UDP socket: %s", strerror(errno));
 
-  if (cw.me.sin_port != 0)
-    dolisten();
-  else if (connaddr.sin_addr.s_addr == INADDR_ANY)
-    dofwd(STDIN_FILENO, STDOUT_FILENO);
-  else {
-    if ((fd = socket(PF_INET, SOCK_STREAM, 0)) < 0 ||
-       bind(fd, (struct sockaddr *)&bindaddr, sizeof(bindaddr)) ||
-       connect(fd, (struct sockaddr *)&connaddr, sizeof(connaddr)) ||
-       nonblockify(fd) || cloexec(fd))
-      die(1, "couldn't connect to TCP server: %s", strerror(errno));
-    dofwd(fd, fd);
-  }
+  if (bindsvc) dolisten();
+  else if (connhost) dofwd(fd, fd);
+  else dofwd(STDIN_FILENO, STDOUT_FILENO);
 
   for (;;) {
     if (sel_select(&sel) && errno != EINTR)