chiark / gitweb /
Merge branch '1.0.0pre19.x'
[tripe] / svc / conntrack.in
index 83b2993c7c347495ea340b4cab02606a7b8c76a7..28e4b0b3321bedad801ca71fda6a7fe11bc4c1cc 100644 (file)
@@ -36,11 +36,13 @@ import socket as S
 import mLib as M
 import tripe as T
 import dbus as D
 import mLib as M
 import tripe as T
 import dbus as D
+import re as RX
 for i in ['mainloop', 'mainloop.glib']:
   __import__('dbus.%s' % i)
 try: from gi.repository import GLib as G
 except ImportError: import gobject as G
 from struct import pack, unpack
 for i in ['mainloop', 'mainloop.glib']:
   __import__('dbus.%s' % i)
 try: from gi.repository import GLib as G
 except ImportError: import gobject as G
 from struct import pack, unpack
+from cStringIO import StringIO
 
 SM = T.svcmgr
 ##__import__('rmcr').__debug = True
 
 SM = T.svcmgr
 ##__import__('rmcr').__debug = True
@@ -53,93 +55,132 @@ class struct (object):
   def __init__(me, **kw):
     me.__dict__.update(kw)
 
   def __init__(me, **kw):
     me.__dict__.update(kw)
 
-def toposort(cmp, things):
-  """
-  Generate the THINGS in an order consistent with a given partial order.
-
-  The function CMP(X, Y) should return true if X must precede Y, and false if
-  it doesn't care.  If X and Y are equal then it should return false.
+def loadb(s):
+  n = 0
+  for ch in s: n = 256*n + ord(ch)
+  return n
 
 
-  The THINGS may be any finite iterable; it is converted to a list
-  internally.
-  """
-
-  ## Make sure we can index the THINGS, and prepare an ordering table.
-  ## What's going on?  The THINGS might not have a helpful equality
-  ## predicate, so it's easier to work with indices.  The ordering table will
-  ## remember which THINGS (by index) are considered greater than other
-  ## things.
-  things = list(things)
-  n = len(things)
-  order = [{} for i in xrange(n)]
-  rorder = [{} for i in xrange(n)]
-  for i in xrange(n):
-    for j in xrange(n):
-      if i != j and cmp(things[i], things[j]):
-        order[j][i] = True
-        rorder[i][j] = True
-
-  ## Now we can do the sort.
-  out = []
-  while True:
-    done = True
-    for i in xrange(n):
-      if order[i] is not None:
-        done = False
-        if len(order[i]) == 0:
-          for j in rorder[i]:
-            del order[j][i]
-          yield things[i]
-          order[i] = None
-    if done:
-      break
+def storeb(n, wd = None):
+  if wd is None: wd = n.bit_length()
+  s = StringIO()
+  for i in xrange((wd - 1)&-8, -8, -8): s.write(chr((n >> i)&0xff))
+  return s.getvalue()
 
 ###--------------------------------------------------------------------------
 ### Address manipulation.
 
 ###--------------------------------------------------------------------------
 ### Address manipulation.
+###
+### I think this is the most demanding application, in terms of address
+### hacking, in the entire TrIPE suite.  At least we don't have to do it in
+### C.
 
 
-class InetAddress (object):
+class BaseAddress (object):
   def __init__(me, addrstr, maskstr = None):
   def __init__(me, addrstr, maskstr = None):
-    me.addr = me._addrstr_to_int(addrstr)
+    me._setaddr(addrstr)
     if maskstr is None:
       me.mask = -1
     elif maskstr.isdigit():
     if maskstr is None:
       me.mask = -1
     elif maskstr.isdigit():
-      me.mask = (1 << 32) - (1 << 32 - int(maskstr))
+      me.mask = (1 << me.NBITS) - (1 << me.NBITS - int(maskstr))
     else:
     else:
-      me.mask = me._addrstr_to_int(maskstr)
+      me._setmask(maskstr)
     if me.addr&~me.mask:
       raise ValueError('network contains bits set beyond mask')
   def _addrstr_to_int(me, addrstr):
     if me.addr&~me.mask:
       raise ValueError('network contains bits set beyond mask')
   def _addrstr_to_int(me, addrstr):
-    return unpack('>L', S.inet_aton(addrstr))[0]
+    try: return loadb(S.inet_pton(me.AF, addrstr))
+    except S.error: raise ValueError('bad address syntax')
   def _int_to_addrstr(me, n):
   def _int_to_addrstr(me, n):
-    return S.inet_ntoa(pack('>L', n))
+    return S.inet_ntop(me.AF, storeb(me.addr, me.NBITS))
+  def _setmask(me, maskstr):
+    raise ValueError('only prefix masked supported')
+  def _maskstr(me):
+    raise ValueError('only prefix masked supported')
   def sockaddr(me, port = 0):
     if me.mask != -1: raise ValueError('not a simple address')
   def sockaddr(me, port = 0):
     if me.mask != -1: raise ValueError('not a simple address')
-    return me._int_to_addrstr(me.addr), port
+    return me._sockaddr(port)
   def __str__(me):
   def __str__(me):
-    addrstr = me._int_to_addrstr(me.addr)
+    addrstr = me._addrstr()
     if me.mask == -1:
       return addrstr
     else:
     if me.mask == -1:
       return addrstr
     else:
-      inv = me.mask ^ ((1 << 32) - 1)
+      inv = me.mask ^ ((1 << me.NBITS) - 1)
       if (inv&(inv + 1)) == 0:
       if (inv&(inv + 1)) == 0:
-        return '%s/%d' % (addrstr, 32 - inv.bit_length())
+        return '%s/%d' % (addrstr, me.NBITS - inv.bit_length())
       else:
       else:
-        return '%s/%s' % (addrstr, me._int_to_addrstr(me.mask))
+        return '%s/%s' % (addrstr, me._maskstr())
   def withinp(me, net):
   def withinp(me, net):
+    if type(net) != type(me): return False
     if (me.mask&net.mask) != net.mask: return False
     if (me.addr ^ net.addr)&net.mask: return False
     if (me.mask&net.mask) != net.mask: return False
     if (me.addr ^ net.addr)&net.mask: return False
-    return True
+    return me._withinp(net)
   def eq(me, other):
   def eq(me, other):
+    if type(me) != type(other): return False
     if me.mask != other.mask: return False
     if me.addr != other.addr: return False
     if me.mask != other.mask: return False
     if me.addr != other.addr: return False
+    return me._eq(other)
+  def _withinp(me, net):
     return True
     return True
+  def _eq(me, other):
+    return True
+
+class InetAddress (BaseAddress):
+  AF = S.AF_INET
+  AFNAME = 'IPv4'
+  NBITS = 32
+  def _addrstr_to_int(me, addrstr):
+    try: return loadb(S.inet_aton(addrstr))
+    except S.error: raise ValueError('bad address syntax')
+  def _setaddr(me, addrstr):
+    me.addr = me._addrstr_to_int(addrstr)
+  def _setmask(me, maskstr):
+    me.mask = me._addrstr_to_int(maskstr)
+  def _addrstr(me):
+    return me._int_to_addrstr(me.addr)
+  def _maskstr(me):
+    return me._int_to_addrstr(me.mask)
+  def _sockaddr(me, port = 0):
+    return (me._addrstr(), port)
   @classmethod
   def from_sockaddr(cls, sa):
     addr, port = (lambda a, p: (a, p))(*sa)
     return cls(addr), port
 
   @classmethod
   def from_sockaddr(cls, sa):
     addr, port = (lambda a, p: (a, p))(*sa)
     return cls(addr), port
 
+class Inet6Address (BaseAddress):
+  AF = S.AF_INET6
+  AFNAME = 'IPv6'
+  NBITS = 128
+  def _setaddr(me, addrstr):
+    pc = addrstr.find('%')
+    if pc == -1:
+      me.addr = me._addrstr_to_int(addrstr)
+      me.scope = 0
+    else:
+      me.addr = me._addrstr_to_int(addrstr[:pc])
+      ais = S.getaddrinfo(addrstr, 0, S.AF_INET6, S.SOCK_DGRAM, 0,
+                          S.AI_NUMERICHOST | S.AI_NUMERICSERV)
+      me.scope = ais[0][4][3]
+  def _addrstr(me):
+    addrstr = me._int_to_addrstr(me.addr)
+    if me.scope == 0:
+      return addrstr
+    else:
+      name, _ = S.getnameinfo((addrstr, 0, 0, me.scope),
+                              S.NI_NUMERICHOST | S.NI_NUMERICSERV)
+      return name
+  def _sockaddr(me, port = 0):
+    return (me._addrstr(), port, 0, me.scope)
+  @classmethod
+  def from_sockaddr(cls, sa):
+    addr, port, _, scope = (lambda a, p, f = 0, s = 0: (a, p, f, s))(*sa)
+    me = cls(addr)
+    me.scope = scope
+    return me, port
+  def _withinp(me, net):
+    return net.scope == 0 or me.scope == net.scope
+  def _eq(me, other):
+    return me.scope == other.scope
+
 def parse_address(addrstr, maskstr = None):
 def parse_address(addrstr, maskstr = None):
-  return InetAddress(addrstr, maskstr)
+  if addrstr.find(':') >= 0: return Inet6Address(addrstr, maskstr)
+  else: return InetAddress(addrstr, maskstr)
 
 def parse_net(netstr):
   try: sl = netstr.index('/')
 
 def parse_net(netstr):
   try: sl = netstr.index('/')
@@ -156,7 +197,20 @@ def straddr(a): return a is None and '#<none>' or str(a)
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
 ## this service are largely going to be satellite notes, I don't think
 ## scalability's going to be a problem.
 
-TESTADDR = InetAddress('1.2.3.4')
+TESTADDRS = [InetAddress('1.2.3.4'), Inet6Address('2001::1')]
+
+CONFSYNTAX = [
+  ('COMMENT', RX.compile(r'^\s*($|[;#])')),
+  ('GRPHDR', RX.compile(r'^\s*\[(.*)\]\s*$')),
+  ('ASSGN', RX.compile(r'\s*([\w.-]+)\s*[:=]\s*(|\S|\S.*\S)\s*$'))]
+
+class ConfigError (Exception):
+  def __init__(me, file, lno, msg):
+    me.file = file
+    me.lno = lno
+    me.msg = msg
+  def __str__(me):
+    return '%s:%d: %s' % (me.file, me.lno, me.msg)
 
 class Config (object):
   """
 
 class Config (object):
   """
@@ -164,11 +218,11 @@ class Config (object):
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
 
   The most interesting thing is probably the `groups' slot, which stores a
   list of pairs (NAME, PATTERNS); the NAME is a string, and the PATTERNS a
-  list of (TAG, PEER, NET) triples.  The implication is that there should be
+  list of (TAG, PEER, NETS) triples.  The implication is that there should be
   precisely one peer from the set, and that it should be named TAG, where
   precisely one peer from the set, and that it should be named TAG, where
-  (TAG, PEER, NET) is the first triple such that the host's primary IP
+  (TAG, PEER, NETS) is the first triple such that the host's primary IP
   address (if PEER is None -- or the IP address it would use for
   address (if PEER is None -- or the IP address it would use for
-  communicating with PEER) is within the NET.
+  communicating with PEER) is within one of the networks defined by NETS.
   """
 
   def __init__(me, file):
   """
 
   def __init__(me, file):
@@ -191,74 +245,136 @@ class Config (object):
     Internal function to update the configuration from the underlying file.
     """
 
     Internal function to update the configuration from the underlying file.
     """
 
-    ## Read the configuration.  We have no need of the fancy substitutions,
-    ## so turn them all off.
-    cp = RawConfigParser()
-    cp.read(me._file)
     if T._debug: print '# reread config'
 
     if T._debug: print '# reread config'
 
-    ## Save the test address.  Make sure it's vaguely sensible.  The default
-    ## is probably good for most cases, in fact, since that address isn't
-    ## actually in use.  Note that we never send packets to the test address;
-    ## we just use it to discover routing information.
-    if cp.has_option('DEFAULT', 'test-addr'):
-      testaddr = InetAddress(cp.get('DEFAULT', 'test-addr'))
-    else:
-      testaddr = TESTADDR
-
-    ## Scan the configuration file and build the groups structure.
-    groups = []
-    for sec in cp.sections():
-      pats = []
-      for tag in cp.options(sec):
-        spec = cp.get(sec, tag).split()
-
-        ## Parse the entry into peer and network.
-        if len(spec) == 1:
-          peer = None
-          net = spec[0]
+    ## Initial state.
+    testaddrs = {}
+    groups = {}
+    grpname = None
+    grplist = []
+
+    ## Open the file and start reading.
+    with open(me._file) as f:
+      lno = 0
+      for line in f:
+        lno += 1
+        for tag, rx in CONFSYNTAX:
+          m = rx.match(line)
+          if m: break
         else:
         else:
-          peer = InetAddress(spec[0])
-          net = spec[1]
-
-        ## Syntax of a net is ADDRESS/MASK, where ADDRESS is a dotted-quad,
-        ## and MASK is either a dotted-quad or a single integer N indicating
-        ## a mask with N leading ones followed by trailing zeroes.
-        net = parse_net(net)
-        pats.append((tag, peer, net))
-
-      ## Annoyingly, RawConfigParser doesn't preserve the order of options.
-      ## In order to make things vaguely sane, we topologically sort the
-      ## patterns so that more specific patterns are checked first.
-      pats = list(toposort(lambda (t, p, n), (tt, pp, nn): \
-                             (p and not pp) or \
-                             (p == pp and n.withinp(nn)),
-                           pats))
-      groups.append((sec, pats))
+          raise ConfigError(me._file, lno, 'failed to parse line: %r' % line)
+
+        if tag == 'COMMENT':
+          ## A comment.  Ignore it and hope it goes away.
+
+          continue
+
+        elif tag == 'GRPHDR':
+          ## A group header.  Flush the old group and start a new one.
+          newname = m.group(1)
+
+          if grpname is not None: groups[grpname] = grplist
+          if newname in groups:
+            raise ConfigError(me._file, lno,
+                              "duplicate group name `%s'" % newname)
+          grpname = newname
+          grplist = []
+
+        elif tag == 'ASSGN':
+           ## An assignment.  Deal with it.
+          name, value = m.group(1), m.group(2)
+
+          if grpname is None:
+            ## We're outside of any group, so this is a global configuration
+            ## tweak.
+
+            if name == 'test-addr':
+              for astr in value.split():
+                try:
+                  a = parse_address(astr)
+                except Exception, e:
+                  raise ConfigError(me._file, lno,
+                                    "invalid IP address `%s': %s" %
+                                    (astr, e))
+                if a.AF in testaddrs:
+                  raise ConfigError(me._file, lno,
+                                    'duplicate %s test-address' % a.AFNAME)
+                testaddrs[a.AF] = a
+            else:
+              raise ConfigError(me._file, lno,
+                                "unknown global option `%s'" % name)
+
+          else:
+            ## Parse a pattern and add it to the group.
+            spec = value.split()
+            i = 0
+
+            ## Check for an explicit target address.
+            if i >= len(spec) or spec[i].find('/') >= 0:
+              peer = None
+              af = None
+            else:
+              try:
+                peer = parse_address(spec[i])
+              except Exception, e:
+                raise ConfigError(me._file, lno,
+                                  "invalid IP address `%s': %s" %
+                                  (spec[i], e))
+              af = peer.AF
+              i += 1
+
+            ## Parse the list of local networks.
+            nets = []
+            while i < len(spec):
+              try:
+                net = parse_net(spec[i])
+              except Exception, e:
+                raise ConfigError(me._file, lno,
+                                  "invalid IP network `%s': %s" %
+                                  (spec[i], e))
+              else:
+                nets.append(net)
+              i += 1
+            if not nets:
+              raise ConfigError(me._file, lno, 'no networks defined')
+
+            ## Make sure that the addresses are consistent.
+            for net in nets:
+              if af is None:
+                af = net.AF
+              elif net.AF != af:
+                raise ConfigError(me._file, lno,
+                                  "net %s doesn't match" % net)
+
+            ## Add this entry to the list.
+            grplist.append((name, peer, nets))
+
+    ## Fill in the default test addresses if necessary.
+    for a in TESTADDRS: testaddrs.setdefault(a.AF, a)
 
     ## Done.
 
     ## Done.
-    me.testaddr = testaddr
+    if grpname is not None: groups[grpname] = grplist
+    me.testaddrs = testaddrs
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
     me.groups = groups
 
 ### This will be a configuration file.
 CF = None
 
 def cmd_showconfig():
-  T.svcinfo('test-addr=%s' % CF.testaddr)
+  T.svcinfo('test-addr=%s' %
+            ' '.join(str(a)
+                     for a in sorted(CF.testaddrs.itervalues(),
+                                     key = lambda a: a.AFNAME)))
 def cmd_showgroups():
 def cmd_showgroups():
-  for sec, pats in CF.groups:
-    T.svcinfo(sec)
+  for g in sorted(CF.groups.iterkeys()):
+    T.svcinfo(g)
 def cmd_showgroup(g):
 def cmd_showgroup(g):
-  for s, p in CF.groups:
-    if s == g:
-      pats = p
-      break
-  else:
-    raise T.TripeJobError('unknown-group', g)
-  for t, p, n in pats:
+  try: pats = CF.groups[g]
+  except KeyError: raise T.TripeJobError('unknown-group', g)
+  for t, p, nn in pats:
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
     T.svcinfo('peer', t,
               'target', p and str(p) or '(default)',
-              'net', str(n))
+              'net', ' '.join(map(str, nn)))
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
 
 ###--------------------------------------------------------------------------
 ### Responding to a network up/down event.
@@ -267,12 +383,12 @@ def localaddr(peer):
   """
   Return the local IP address used for talking to PEER.
   """
   """
   Return the local IP address used for talking to PEER.
   """
-  sk = S.socket(S.AF_INET, S.SOCK_DGRAM)
+  sk = S.socket(peer.AF, S.SOCK_DGRAM)
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
   try:
     try:
       sk.connect(peer.sockaddr(1))
       addr = sk.getsockname()
-      return InetAddress.from_sockaddr(addr)[0]
+      return type(peer).from_sockaddr(addr)[0]
     except S.error:
       return None
   finally:
     except S.error:
       return None
   finally:
@@ -327,53 +443,52 @@ def kickpeers():
     ## Find the current list of peers.
     peers = SM.list()
 
     ## Find the current list of peers.
     peers = SM.list()
 
-    ## Work out the primary IP address.
+    ## Work out the primary IP addresses.
+    locals = {}
     if upness:
     if upness:
-      addr = localaddr(CF.testaddr)
-      if addr is None:
-        upness = False
-    else:
-      addr = None
+      for af, remote in CF.testaddrs.iteritems():
+        local = localaddr(remote)
+        if local is not None: locals[af] = local
+      if not locals: upness = False
     if not T._debug: pass
     if not T._debug: pass
-    elif addr: print '#   local address = %s' % straddr(addr)
-    else: print '#   offline'
+    elif not locals: print '#   offline'
+    else:
+      for local in locals.itervalues():
+        print '#   local %s address = %s' % (local.AFNAME, local)
 
     ## Now decide what to do.
     changes = []
 
     ## Now decide what to do.
     changes = []
-    for g, pp in CF.groups:
+    for g, pp in CF.groups.iteritems():
       if T._debug: print '#   check group %s' % g
 
       ## Find out which peer in the group ought to be active.
       if T._debug: print '#   check group %s' % g
 
       ## Find out which peer in the group ought to be active.
-      ip = None
-      map = {}
+      statemap = {}
       want = None
       want = None
-      for t, p, n in pp:
-        if p is None or not upness:
-          ipq = addr
-        else:
-          ipq = localaddr(p)
+      matchp = False
+      for t, p, nn in pp:
+        af = nn[0].AF
+        if p is None or not upness: ip = locals.get(af)
+        else: ip = localaddr(p)
         if T._debug:
         if T._debug:
-          info = 'peer=%s; target=%s; net=%s; local=%s' % (
-            t, p or '(default)', n, straddr(ipq))
-        if upness and ip is None and \
-              ipq is not None and ipq.withinp(n):
+          info = 'peer = %s; target = %s; nets = %s; local = %s' % (
+            t, p or '(default)', ', '.join(map(str, nn)), straddr(ip))
+        if upness and not matchp and \
+           ip is not None and any(ip.withinp(n) for n in nn):
           if T._debug: print '#     %s: SELECTED' % info
           if T._debug: print '#     %s: SELECTED' % info
-          map[t] = 'up'
+          statemap[t] = 'up'
           select.append('%s=%s' % (g, t))
           select.append('%s=%s' % (g, t))
-          if t == 'down' or t.startswith('down/'):
-            want = None
-          else:
-            want = t
-          ip = ipq
+          if t == 'down' or t.startswith('down/'): want = None
+          else: want = t
+          matchp = True
         else:
         else:
-          map[t] = 'down'
+          statemap[t] = 'down'
           if T._debug: print '#     %s: skipped' % info
 
       ## Shut down the wrong ones.
       found = False
           if T._debug: print '#     %s: skipped' % info
 
       ## Shut down the wrong ones.
       found = False
-      if T._debug: print '#   peer-map = %r' % map
+      if T._debug: print '#   peer-map = %r' % statemap
       for p in peers:
       for p in peers:
-        what = map.get(p, 'leave')
+        what = statemap.get(p, 'leave')
         if what == 'up':
           found = True
           if T._debug: print '#   peer %s: already up' % p
         if what == 'up':
           found = True
           if T._debug: print '#   peer %s: already up' % p