chiark / gitweb /
svc/conntrack.in: Make an `InetAddress' class to do address wrangling.
authorMark Wooding <mdw@distorted.org.uk>
Thu, 28 Sep 2017 18:54:32 +0000 (19:54 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Thu, 28 Jun 2018 23:29:23 +0000 (00:29 +0100)
The name is a little misleading: it can also represent a network, but
separating the two turns out to be a little tedious, so I don't bother.

This means that the configuration now actually contains (PEER,
TEST-ADDRESS, LOCAL-NET) triples, rather than keeping the address and
mask portions of the LOCAL-NET separate.

This is rather an invasive change.  Sorry.

svc/conntrack.in

index 2dc380b953800f8fdea3538fb2dd9626bd81a78c..83b2993c7c347495ea340b4cab02606a7b8c76a7 100644 (file)
@@ -97,26 +97,56 @@ def toposort(cmp, things):
 ###--------------------------------------------------------------------------
 ### Address manipulation.
 
-def parse_address(addrstr):
-  return unpack('>L', S.inet_aton(addrstr))[0]
+class InetAddress (object):
+  def __init__(me, addrstr, maskstr = None):
+    me.addr = me._addrstr_to_int(addrstr)
+    if maskstr is None:
+      me.mask = -1
+    elif maskstr.isdigit():
+      me.mask = (1 << 32) - (1 << 32 - int(maskstr))
+    else:
+      me.mask = me._addrstr_to_int(maskstr)
+    if me.addr&~me.mask:
+      raise ValueError('network contains bits set beyond mask')
+  def _addrstr_to_int(me, addrstr):
+    return unpack('>L', S.inet_aton(addrstr))[0]
+  def _int_to_addrstr(me, n):
+    return S.inet_ntoa(pack('>L', n))
+  def sockaddr(me, port = 0):
+    if me.mask != -1: raise ValueError('not a simple address')
+    return me._int_to_addrstr(me.addr), port
+  def __str__(me):
+    addrstr = me._int_to_addrstr(me.addr)
+    if me.mask == -1:
+      return addrstr
+    else:
+      inv = me.mask ^ ((1 << 32) - 1)
+      if (inv&(inv + 1)) == 0:
+        return '%s/%d' % (addrstr, 32 - inv.bit_length())
+      else:
+        return '%s/%s' % (addrstr, me._int_to_addrstr(me.mask))
+  def withinp(me, net):
+    if (me.mask&net.mask) != net.mask: return False
+    if (me.addr ^ net.addr)&net.mask: return False
+    return True
+  def eq(me, other):
+    if me.mask != other.mask: return False
+    if me.addr != other.addr: return False
+    return True
+  @classmethod
+  def from_sockaddr(cls, sa):
+    addr, port = (lambda a, p: (a, p))(*sa)
+    return cls(addr), port
+
+def parse_address(addrstr, maskstr = None):
+  return InetAddress(addrstr, maskstr)
 
 def parse_net(netstr):
   try: sl = netstr.index('/')
   except ValueError: raise ValueError('missing mask')
-  addr = parse_address(netstr[:sl])
-  if netstr[sl + 1:].isdigit():
-    n = int(netstr[sl + 1:], 10)
-    mask = (1 << 32) - (1 << 32 - n)
-  else:
-    mask = parse_address(netstr[sl + 1:])
-  if addr&~mask: raise ValueError('network contains bits set beyond mask')
-  return addr, mask
+  return parse_address(netstr[:sl], netstr[sl + 1:])
 
-def straddr(a): return a is None and '#<none>' or S.inet_ntoa(pack('>L', a))
-def strmask(m):
-  for i in xrange(33):
-    if m == 0xffffffff ^ ((1 << (32 - i)) - 1): return str(i)
-  return straddr(m)
+def straddr(a): return a is None and '#<none>' or str(a)
 
 ###--------------------------------------------------------------------------
 ### Parse the configuration file.
@@ -126,18 +156,19 @@ def strmask(m):
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
+TESTADDR = InetAddress('1.2.3.4')
+
 class Config (object):
   """
   Represents a configuration file.
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
-  list of (TAG, PEER, ADDR, MASK) triples.  The implication is that there
-  should be precisely one peer with a name matching NAME-*, and that it
-  should be NAME-TAG, where (TAG, PEER, ADDR, MASK) is the first triple such
-  that the host's primary IP address (if PEER is None -- or the IP address it
-  would use for communicating with PEER) is within the network defined by
-  ADDR/MASK.
+  list of (TAG, PEER, NET) triples.  The implication is that there should be
+  precisely one peer from the set, and that it should be named TAG, where
+  (TAG, PEER, NET) is the first triple such that the host's primary IP
+  address (if PEER is None -- or the IP address it would use for
+  communicating with PEER) is within the NET.
   """
 
   def __init__(me, file):
@@ -171,10 +202,9 @@ class Config (object):
     ## actually in use.  Note that we never send packets to the test address;
     ## we just use it to discover routing information.
     if cp.has_option('DEFAULT', 'test-addr'):
-      testaddr = cp.get('DEFAULT', 'test-addr')
-      S.inet_aton(testaddr)
+      testaddr = InetAddress(cp.get('DEFAULT', 'test-addr'))
     else:
-      testaddr = '1.2.3.4'
+      testaddr = TESTADDR
 
     ## Scan the configuration file and build the groups structure.
     groups = []
@@ -188,20 +218,21 @@ class Config (object):
           peer = None
           net = spec[0]
         else:
-          peer, net = spec
+          peer = InetAddress(spec[0])
+          net = spec[1]
 
         ## Syntax of a net is ADDRESS/MASK, where ADDRESS is a dotted-quad,
         ## and MASK is either a dotted-quad or a single integer N indicating
         ## a mask with N leading ones followed by trailing zeroes.
-        addr, mask = parse_net(net)
-        pats.append((tag, peer, addr, mask))
+        net = parse_net(net)
+        pats.append((tag, peer, net))
 
       ## Annoyingly, RawConfigParser doesn't preserve the order of options.
       ## In order to make things vaguely sane, we topologically sort the
       ## patterns so that more specific patterns are checked first.
-      pats = list(toposort(lambda (t, p, a, m), (tt, pp, aa, mm): \
+      pats = list(toposort(lambda (t, p, n), (tt, pp, nn): \
                              (p and not pp) or \
-                             (p == pp and m == (m | mm) and aa == (a & mm)),
+                             (p == pp and n.withinp(nn)),
                            pats))
       groups.append((sec, pats))
 
@@ -224,10 +255,10 @@ def cmd_showgroup(g):
       break
   else:
     raise T.TripeJobError('unknown-group', g)
-  for t, p, a, m in pats:
+  for t, p, n in pats:
     T.svcinfo('peer', t,
-              'target', p or '(default)',
-              'net', '%s/%s' % (straddr(a), strmask(m)))
+              'target', p and str(p) or '(default)',
+              'net', str(n))
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
@@ -239,10 +270,9 @@ def localaddr(peer):
   sk = S.socket(S.AF_INET, S.SOCK_DGRAM)
   try:
     try:
-      sk.connect((peer, 1))
-      addr, _ = sk.getsockname()
-      addr = parse_address(addr)
-      return addr
+      sk.connect(peer.sockaddr(1))
+      addr = sk.getsockname()
+      return InetAddress.from_sockaddr(addr)[0]
     except S.error:
       return None
   finally:
@@ -317,16 +347,16 @@ def kickpeers():
       ip = None
       map = {}
       want = None
-      for t, p, a, m in pp:
+      for t, p, n in pp:
         if p is None or not upness:
           ipq = addr
         else:
           ipq = localaddr(p)
         if T._debug:
-          info = 'peer=%s; target=%s; net=%s/%s; local=%s' % (
-            t, p or '(default)', straddr(a), strmask(m), straddr(ipq))
+          info = 'peer=%s; target=%s; net=%s; local=%s' % (
+            t, p or '(default)', n, straddr(ipq))
         if upness and ip is None and \
-              ipq is not None and (ipq & m) == a:
+              ipq is not None and ipq.withinp(n):
           if T._debug: print '#     %s: SELECTED' % info
           map[t] = 'up'
           select.append('%s=%s' % (g, t))