chiark / gitweb /
pathmtu/pathmtu.c: Replace explicit `sockaddr_in' structures with union.
[tripe] / pathmtu / pathmtu.c
index f42602e03bce709391ff7765b5fa7fab8c640d01..dbcba593ca9e5b967118badbf019301da2d8660d 100644 (file)
@@ -105,6 +105,33 @@ static double s2f(const char *s, const char *what)
 static void f2tv(struct timeval *tv, double t)
   { tv->tv_sec = t; tv->tv_usec = (t - tv->tv_sec)*MILLION; }
 
+union addr {
+  struct sockaddr sa;
+  struct sockaddr_in sin;
+};
+
+/* Return the size of a socket address. */
+static size_t addrsz(const union addr *a)
+{
+  switch (a->sa.sa_family) {
+    case AF_INET: return (sizeof(a->sin));
+    default: abort();
+  }
+}
+
+/* Compare two addresses.  Maybe compare the port numbers too. */
+#define AEF_PORT 1u
+static int addreq(const union addr *a, const union addr *b, unsigned f)
+{
+  switch (a->sa.sa_family) {
+    case AF_INET:
+      return (a->sin.sin_addr.s_addr == b->sin.sin_addr.s_addr &&
+             (!(f&AEF_PORT) || a->sin.sin_port == b->sin.sin_port));
+    default:
+      abort();
+  }
+}
+
 /*----- Main algorithm skeleton -------------------------------------------*/
 
 struct param {
@@ -115,7 +142,7 @@ struct param {
   double timeout;                      /* Retransmission timeout */
   int seqoff;                          /* Offset to write sequence number */
   const struct probe_ops *pops;                /* Probe algorithm description */
-  struct sockaddr_in sin;              /* Destination address */
+  union addr a;                                /* Destination address */
 };
 
 struct probestate {
@@ -222,8 +249,9 @@ static int pathmtu(const struct param *pp)
   /* Build and connect a UDP socket.  We'll need this to know the local port
    * number to use if nothing else.  Set other stuff up.
    */
-  if ((sk = socket(PF_INET, SOCK_DGRAM, 0)) < 0) goto fail_0;
-  if (connect(sk, (struct sockaddr *)&pp->sin, sizeof(pp->sin))) goto fail_1;
+  if ((sk = socket(pp->a.sa.sa_family, SOCK_DGRAM, IPPROTO_UDP)) < 0)
+    goto fail_0;
+  if (connect(sk, &pp->a.sa, addrsz(&pp->a))) goto fail_1;
   st = xmalloc(pp->pops->statesz);
   if ((mtu = pp->pops->setup(st, sk, pp)) < 0) goto fail_2;
   ps.pp = pp; ps.q = rand() & 0xffff;
@@ -390,7 +418,7 @@ struct phdr {
 };
 
 struct raw_state {
-  struct sockaddr_in me, sin;
+  union addr me, a;
   int sk, rawicmp, rawudp;
   unsigned q;
 };
@@ -403,17 +431,25 @@ static int raw_setup(void *stv, int sk, const struct param *pp)
   struct ifaddrs *ifa, *ifaa, *ifap;
   struct ifreq ifr;
 
-  /* If we couldn't acquire raw sockets, we fail here. */
-  if (rawerr) { errno = rawerr; goto fail_0; }
-  st->rawicmp = rawicmp; st->rawudp = rawudp; st->sk = sk;
+  /* Check that the address is OK, and that we have the necessary raw
+   * sockets.
+   */
+  switch (pp->a.sa.sa_family) {
+    case AF_INET:
+      if (rawerr) { errno = rawerr; goto fail_0; }
+      st->rawicmp = rawicmp; st->rawudp = rawudp; st->sk = sk;
+      break;
+    default:
+      errno = EPFNOSUPPORT; goto fail_0;
+  }
 
   /* Initialize the sequence number. */
   st->q = rand() & 0xffff;
 
   /* Snaffle the local and remote address and port number. */
-  st->sin = pp->sin;
+  st->a = pp->a;
   sz = sizeof(st->me);
-  if (getsockname(sk, (struct sockaddr *)&st->me, &sz))
+  if (getsockname(sk, &st->me.sa, &sz))
     goto fail_0;
 
   /* There isn't a portable way to force the DF flag onto a packet through
@@ -434,10 +470,9 @@ static int raw_setup(void *stv, int sk, const struct param *pp)
   for (i = 0; i < 2; i++) {
     for (ifap = 0, ifa = ifaa; ifa; ifa = ifa->ifa_next) {
       if (!(ifa->ifa_flags & IFF_UP) || !ifa->ifa_addr ||
-         ifa->ifa_addr->sa_family != AF_INET ||
+         ifa->ifa_addr->sa_family != st->me.sa.sa_family ||
          (i == 0 &&
-          ((struct sockaddr_in *)ifa->ifa_addr)->sin_addr.s_addr !=
-               st->me.sin_addr.s_addr) ||
+          !addreq((union addr *)ifa->ifa_addr, &st->me, 0)) ||
          (i == 1 && ifap && strcmp(ifap->ifa_name, ifa->ifa_name) == 0) ||
          strlen(ifa->ifa_name) >= sizeof(ifr.ifr_name))
        continue;
@@ -485,13 +520,13 @@ static int raw_xmit(void *stv, int mtu)
   ip->ip_ttl = 64;
   ip->ip_p = IPPROTO_UDP;
   ip->ip_sum = 0;
-  ip->ip_src = st->me.sin_addr;
-  ip->ip_dst = st->sin.sin_addr;
+  ip->ip_src = st->me.sin.sin_addr;
+  ip->ip_dst = st->a.sin.sin_addr;
 
   /* Build a UDP packet in the output buffer. */
   udp = (struct udphdr *)(ip + 1);
-  udp->uh_sport = st->me.sin_port;
-  udp->uh_dport = st->sin.sin_port;
+  udp->uh_sport = st->me.sin.sin_port;
+  udp->uh_dport = st->a.sin.sin_port;
   udp->uh_ulen = htons(mtu - sizeof(*ip));
   udp->uh_sum = 0;
 
@@ -513,8 +548,7 @@ static int raw_xmit(void *stv, int mtu)
   /* Send the whole thing off.  If we're too big for the interface then we
    * might need to trim immediately.
    */
-  if (sendto(st->rawudp, b, mtu, 0,
-            (struct sockaddr *)&st->sin, sizeof(st->sin)) < 0) {
+  if (sendto(st->rawudp, b, mtu, 0, &st->a.sa, addrsz(&st->a)) < 0) {
     if (errno == EMSGSIZE) return (RC_LOWER);
     else goto fail_0;
   }
@@ -555,14 +589,14 @@ static int raw_selproc(void *stv, fd_set *fd_in, struct probestate *ps)
     if (n < sizeof(*ip) ||
        ip->ip_p != IPPROTO_UDP || ip->ip_hl != sizeof(*ip)/4 ||
        ip->ip_id != htons(st->q) ||
-       ip->ip_src.s_addr != st->me.sin_addr.s_addr ||
-       ip->ip_dst.s_addr != st->sin.sin_addr.s_addr)
+       ip->ip_src.s_addr != st->me.sin.sin_addr.s_addr ||
+       ip->ip_dst.s_addr != st->a.sin.sin_addr.s_addr)
       goto skip_icmp;
     n -= sizeof(*ip);
 
     udp = (struct udphdr *)(ip + 1);
-    if (n < sizeof(udp) || udp->uh_sport != st->me.sin_port ||
-       udp->uh_dport != st->sin.sin_port)
+    if (n < sizeof(udp) || udp->uh_sport != st->me.sin.sin_port ||
+       udp->uh_dport != st->a.sin.sin_port)
       goto skip_icmp;
     n -= sizeof(*udp);
 
@@ -618,6 +652,12 @@ static int linux_setup(void *stv, int sk, const struct param *pp)
   int i, mtu;
   socklen_t sz;
 
+  /* Check that the address is OK. */
+  switch (pp->a.sa.sa_family) {
+    case AF_INET: break;
+    default: errno = EPFNOSUPPORT; return (-1);
+  }
+
   /* Snaffle the UDP socket. */
   st->sk = sk;
 
@@ -756,7 +796,7 @@ int main(int argc, char *argv[])
 
   ego(argv[0]);
   fillbuffer(buf, sizeof(buf));
-  pp.sin.sin_port = htons(7);
+  pp.a.sin.sin_port = htons(7);
 
   for (;;) {
     static const struct option opts[] = {
@@ -817,21 +857,21 @@ int main(int argc, char *argv[])
     die(EXIT_FAILURE, "unknown host `%s': %s", *argv, hstrerror(h_errno));
   if (h->h_addrtype != AF_INET)
     die(EXIT_FAILURE, "unsupported address family for host `%s'", *argv);
-  memcpy(&pp.sin.sin_addr, h->h_addr, sizeof(struct in_addr));
+  memcpy(&pp.a.sin.sin_addr, h->h_addr, sizeof(struct in_addr));
   argv++; argc--;
 
   if (*argv) {
     errno = 0;
     u = strtoul(*argv, &q, 0);
     if (!errno && !*q)
-      pp.sin.sin_port = htons(u);
+      pp.a.sin.sin_port = htons(u);
     else if ((s = getservbyname(*argv, "udp")) == 0)
       die(EXIT_FAILURE, "unknown UDP service `%s'", *argv);
     else
-      pp.sin.sin_port = s->s_port;
+      pp.a.sin.sin_port = s->s_port;
   }
 
-  pp.sin.sin_family = AF_INET;
+  pp.a.sin.sin_family = AF_INET;
   i = pathmtu(&pp);
   if (i < 0)
     die(EXIT_FAILURE, "failed to discover MTU: %s", strerror(errno));