X-Git-Url: http://www.chiark.greenend.org.uk/ucgi/~ian/git?a=blobdiff_plain;f=hippotat%2F__init__.py;h=6f0d3e87d8a2ba2dcf5f7d4e1f7ea154b716f22e;hb=88487243bc0be906c63258005df75b96bc8165a5;hp=ee844cd53faa83b09602c3231043b2d9fb588861;hpb=c491fea13a428b0c33df3294b23db7e2773e8dc6;p=hippotat.git diff --git a/hippotat/__init__.py b/hippotat/__init__.py index ee844cd..6f0d3e8 100644 --- a/hippotat/__init__.py +++ b/hippotat/__init__.py @@ -1,10 +1,35 @@ # -*- python -*- -import hippotat.slip as slip +import signal +signal.signal(signal.SIGINT, signal.SIG_DFL) + +import twisted +from twisted.internet import reactor +from twisted.logger import LogLevel import ipaddress from ipaddress import AddressValueError +import hippotat.slip as slip + +from optparse import OptionParser +from configparser import ConfigParser +from configparser import NoOptionError + +import collections + +# these need to be defined here so that they can be imported by import * +cfg = ConfigParser() +optparser = OptionParser() + +class ConfigResults: + def __init__(self, d = { }): + self.__dict__ = d + def __repr__(self): + return 'ConfigResults('+repr(self.__dict__)+')' + +c = ConfigResults() + #---------- packet parsing ---------- def packet_addrs(packet): @@ -38,3 +63,158 @@ def ipnetwork(input): except NetworkValueError: r = ipaddress.IPv6Network(input) return r + +#---------- ipif (SLIP) subprocess ---------- + +class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol): + def __init__(self, router): + self._buffer = b'' + self._router = router + def connectionMade(self): pass + def outReceived(self, data): + #print('RECV ', repr(data)) + self._buffer += data + packets = slip.decode(self._buffer) + self._buffer = packets.pop() + for packet in packets: + if not len(packet): continue + (saddr, daddr) = packet_addrs(packet) + self._router(packet, saddr, daddr) + def processEnded(self, status): + status.raiseException() + +def start_ipif(command, router): + global ipif + ipif = _IpifProcessProtocol(router) + reactor.spawnProcess(ipif, + '/bin/sh',['sh','-xc', command], + childFDs={0:'w', 1:'r', 2:2}) + +def queue_inbound(packet): + ipif.transport.write(slip.delimiter) + ipif.transport.write(slip.encode(packet)) + ipif.transport.write(slip.delimiter) + +#---------- packet queue ---------- + +class PacketQueue(): + def __init__(self, max_queue_time): + self._max_queue_time = max_queue_time + self._pq = collections.deque() # packets + + def append(self, packet): + self._pq.append((time.monotonic(), packet)) + + def nonempty(self): + while True: + try: (queuetime, packet) = self._pq[0] + except IndexError: return False + + age = time.monotonic() - queuetime + if age > self.max_queue_time: + # strip old packets off the front + self._pq.popleft() + continue + + return True + + def popleft(self): + # caller must have checked nonempty + try: (dummy, packet) = self._pq[0] + except IndexError: return None + return packet + +#---------- error handling ---------- + +def crash(err): + print('CRASH ', err, file=sys.stderr) + try: reactor.stop() + except twisted.internet.error.ReactorNotRunning: pass + +def crash_on_defer(defer): + defer.addErrback(lambda err: crash(err)) + +def crash_on_critical(event): + if event.get('log_level') >= LogLevel.critical: + crash(twisted.logger.formatEvent(event)) + +#---------- config processing ---------- + +def process_cfg_common_always(): + global mtu + c.mtu = cfg.get('virtual','mtu') + +def process_cfg_ipif(section, varmap): + for d, s in varmap: + try: v = getattr(c, s) + except KeyError: pass + setattr(c, d, v) + + print(repr(c)) + + c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__) + +def process_cfg_network(): + c.network = ipnetwork(cfg.get('virtual','network')) + if c.network.num_addresses < 3 + 2: + raise ValueError('network needs at least 2^3 addresses') + +def process_cfg_server(): + try: + c.server = cfg.get('virtual','server') + except NoOptionError: + process_cfg_network() + c.server = next(c.network.hosts()) + +class ServerAddr(): + def __init__(self, port, addrspec): + self.port = port + # also self.addr + try: + self.addr = ipaddress.IPv4Address(addrspec) + self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint + self._inurl = '%s' + except AddressValueError: + self.addr = ipaddress.IPv6Address(addrspec) + self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint + self._inurl = '[%s]' + def make_endpoint(self): + return self._endpointfactory(reactor, self.port, self.addr) + def url(self): + url = 'http://' + (self._inurl % self.addr) + if self.port != 80: url += ':%d' % self.port + url += '/' + return url + +def process_cfg_saddrs(): + port = cfg.getint('server','port') + + c.saddrs = [ ] + for addrspec in cfg.get('server','addrs').split(): + sa = ServerAddr(port, addrspec) + c.saddrs.append(sa) + +def process_cfg_clients(constructor): + c.clients = [ ] + for cs in cfg.sections(): + if not (':' in cs or '.' in cs): continue + ci = ipaddr(cs) + pw = cfg.get(cs, 'password') + constructor(ci,cs,pw) + +#---------- startup ---------- + +def common_startup(defcfg): + twisted.logger.globalLogPublisher.addObserver(crash_on_critical) + + optparser.add_option('-c', '--config', dest='configfile', + default='/etc/hippotat/config') + (opts, args) = optparser.parse_args() + if len(args): optparser.error('no non-option arguments please') + + cfg.read_string(defcfg) + cfg.read(opts.configfile) + +def common_run(): + reactor.run() + print('CRASHED (end)', file=sys.stderr)