X-Git-Url: http://www.chiark.greenend.org.uk/ucgi/~ian/git?p=hippotat.git;a=blobdiff_plain;f=hippotat%2F__init__.py;h=ae13eece753bb87b17131c48a32f3173aabe1ef8;hp=ee844cd53faa83b09602c3231043b2d9fb588861;hb=db6ba5840b00c61fe6e576b54179ddeb30202b0e;hpb=c491fea13a428b0c33df3294b23db7e2773e8dc6 diff --git a/hippotat/__init__.py b/hippotat/__init__.py index ee844cd..ae13eec 100644 --- a/hippotat/__init__.py +++ b/hippotat/__init__.py @@ -1,10 +1,125 @@ # -*- python -*- -import hippotat.slip as slip +import signal +signal.signal(signal.SIGINT, signal.SIG_DFL) + +import sys + +import twisted +from twisted.internet import reactor +import twisted.internet.endpoints +import twisted.logger +from twisted.logger import LogLevel +import twisted.python.constants +from twisted.python.constants import NamedConstant import ipaddress from ipaddress import AddressValueError +from optparse import OptionParser +from configparser import ConfigParser +from configparser import NoOptionError + +import collections +import time +import codecs +import traceback + +import re as regexp + +import hippotat.slip as slip + +class DBG(twisted.python.constants.Names): + ROUTE = NamedConstant() + DROP = NamedConstant() + FLOW = NamedConstant() + HTTP = NamedConstant() + HTTP_CTRL = NamedConstant() + INIT = NamedConstant() + QUEUE = NamedConstant() + QUEUE_CTRL = NamedConstant() + HTTP_FULL = NamedConstant() + SLIP_FULL = NamedConstant() + +_hex_codec = codecs.getencoder('hex_codec') + +log = twisted.logger.Logger() + +def log_debug(dflag, msg, idof=None, d=None): + #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr) + if idof is not None: + msg = '[%d] %s' % (id(idof), msg) + if d is not None: + #d = d[0:64] + d = _hex_codec(d)[0].decode('ascii') + msg += ' ' + d + log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg) + +defcfg = ''' +[DEFAULT] +#[] overrides +max_batch_down = 65536 # used by server, subject to [limits] +max_queue_time = 10 # used by server, subject to [limits] +max_request_time = 54 # used by server, subject to [limits] +target_requests_outstanding = 3 # must match; subject to [limits] on server +max_requests_outstanding = 4 # used by client +max_batch_up = 4000 # used by client +http_timeout = 30 # used by client +http_retry = 5 # used by client + +#[server] or [] overrides +ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s +# extra interpolations: %(local)s %(peer)s %(rnet)s +# obtained on server [virtual]server [virtual]relay [virtual]network +# from on client [virtual]server [virtual]routes + +[virtual] +mtu = 1500 +routes = '' +# network = / # mandatory for server +# server = # used by both, default is computed from `network' +# relay = # used by server, default from `network' and `server' +# default server is first host in network +# default relay is first host which is not server + +[server] +# addrs = 127.0.0.1 ::1 # mandatory for server +port = 80 # used by server +# url # used by client; default from first `addrs' and `port' + +# [] +# password = # used by both, must match + +[limits] +max_batch_down = 262144 # used by server +max_queue_time = 121 # used by server +max_request_time = 121 # used by server +target_requests_outstanding = 10 # used by server +''' + +# these need to be defined here so that they can be imported by import * +cfg = ConfigParser() +optparser = OptionParser() + +_mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-') +def mime_translate(s): + # SLIP-encoded packets cannot contain ESC ESC. + # Swap `-' and ESC. The result cannot contain `--' + return s.translate(_mimetrans) + +class ConfigResults: + def __init__(self, d = { }): + self.__dict__ = d + def __repr__(self): + return 'ConfigResults('+repr(self.__dict__)+')' + +c = ConfigResults() + +def log_discard(packet, iface, saddr, daddr, why): + log_debug(DBG.DROP, + 'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why), + d=packet) + #---------- packet parsing ---------- def packet_addrs(packet): @@ -38,3 +153,234 @@ def ipnetwork(input): except NetworkValueError: r = ipaddress.IPv6Network(input) return r + +#---------- ipif (SLIP) subprocess ---------- + +class SlipStreamDecoder(): + def __init__(self, desc, on_packet): + self._buffer = b'' + self._on_packet = on_packet + self._desc = desc + self._log('__init__') + + def _log(self, msg, **kwargs): + log_debug(DBG.SLIP_FULL, 'slip '+msg, **kwargs) + + def inputdata(self, data): + self._log('inputdata', d=data) + data = self._buffer + data + self._buffer = b'' + packets = slip.decode(data) + self._buffer = packets.pop() + for packet in packets: + self._maybe_packet(packet) + self._log('inputdata bufremain', d=self._buffer) + + def _maybe_packet(self, packet): + self._log('inputdata maybepacket', d=packet) + if len(packet): + self._on_packet(packet) + + def flush(self): + self._log('inputdata flush') + self._maybe_packet(self._buffer) + self._buffer = b'' + +class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol): + def __init__(self, router): + self._router = router + self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet) + def connectionMade(self): pass + def outReceived(self, data): + self._decoder.inputdata(data) + def slip_on_packet(self, packet): + (saddr, daddr) = packet_addrs(packet) + if saddr.is_link_local or daddr.is_link_local: + log_discard(packet, 'ipif', saddr, daddr, 'link-local') + return + self._router(packet, saddr, daddr) + def processEnded(self, status): + status.raiseException() + +def start_ipif(command, router): + global ipif + ipif = _IpifProcessProtocol(router) + reactor.spawnProcess(ipif, + '/bin/sh',['sh','-xc', command], + childFDs={0:'w', 1:'r', 2:2}, + env=None) + +def queue_inbound(packet): + log_debug(DBG.FLOW, "queue_inbound", d=packet) + ipif.transport.write(slip.delimiter) + ipif.transport.write(slip.encode(packet)) + ipif.transport.write(slip.delimiter) + +#---------- packet queue ---------- + +class PacketQueue(): + def __init__(self, desc, max_queue_time): + self._desc = desc + assert(desc + '') + self._max_queue_time = max_queue_time + self._pq = collections.deque() # packets + + def _log(self, dflag, msg, **kwargs): + log_debug(dflag, self._desc+' pq: '+msg, **kwargs) + + def append(self, packet): + self._log(DBG.QUEUE, 'append', d=packet) + self._pq.append((time.monotonic(), packet)) + + def nonempty(self): + self._log(DBG.QUEUE, 'nonempty ?') + while True: + try: (queuetime, packet) = self._pq[0] + except IndexError: + self._log(DBG.QUEUE, 'nonempty ? empty.') + return False + + age = time.monotonic() - queuetime + if age > self._max_queue_time: + # strip old packets off the front + self._log(DBG.QUEUE, 'dropping (old)', d=packet) + self._pq.popleft() + continue + + self._log(DBG.QUEUE, 'nonempty ? nonempty.') + return True + + def process(self, sizequery, moredata, max_batch): + # sizequery() should return size of batch so far + # moredata(s) should add s to batch + self._log(DBG.QUEUE, 'process...') + while True: + try: (dummy, packet) = self._pq[0] + except IndexError: + self._log(DBG.QUEUE, 'process... empty') + break + + self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet) + + encoded = slip.encode(packet) + sofar = sizequery() + + self._log(DBG.QUEUE_CTRL, + 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch), + d=encoded) + + if sofar > 0: + if sofar + len(slip.delimiter) + len(encoded) > max_batch: + self._log(DBG.QUEUE_CTRL, 'process... overflow') + break + moredata(slip.delimiter) + + moredata(encoded) + self._pq.popleft() + +#---------- error handling ---------- + +_crashing = False + +def crash(err): + global _crashing + _crashing = True + print('CRASH ', err, file=sys.stderr) + try: reactor.stop() + except twisted.internet.error.ReactorNotRunning: pass + +def crash_on_defer(defer): + defer.addErrback(lambda err: crash(err)) + +def crash_on_critical(event): + if event.get('log_level') >= LogLevel.critical: + crash(twisted.logger.formatEvent(event)) + +#---------- config processing ---------- + +def process_cfg_common_always(): + global mtu + c.mtu = cfg.get('virtual','mtu') + +def process_cfg_ipif(section, varmap): + for d, s in varmap: + try: v = getattr(c, s) + except AttributeError: continue + setattr(c, d, v) + + #print(repr(c)) + + c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__) + +def process_cfg_network(): + c.network = ipnetwork(cfg.get('virtual','network')) + if c.network.num_addresses < 3 + 2: + raise ValueError('network needs at least 2^3 addresses') + +def process_cfg_server(): + try: + c.server = cfg.get('virtual','server') + except NoOptionError: + process_cfg_network() + c.server = next(c.network.hosts()) + +class ServerAddr(): + def __init__(self, port, addrspec): + self.port = port + # also self.addr + try: + self.addr = ipaddress.IPv4Address(addrspec) + self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint + self._inurl = b'%s' + except AddressValueError: + self.addr = ipaddress.IPv6Address(addrspec) + self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint + self._inurl = b'[%s]' + def make_endpoint(self): + return self._endpointfactory(reactor, self.port, self.addr) + def url(self): + url = b'http://' + (self._inurl % str(self.addr).encode('ascii')) + if self.port != 80: url += b':%d' % self.port + url += b'/' + return url + +def process_cfg_saddrs(): + try: port = cfg.getint('server','port') + except NoOptionError: port = 80 + + c.saddrs = [ ] + for addrspec in cfg.get('server','addrs').split(): + sa = ServerAddr(port, addrspec) + c.saddrs.append(sa) + +def process_cfg_clients(constructor): + c.clients = [ ] + for cs in cfg.sections(): + if not (':' in cs or '.' in cs): continue + ci = ipaddr(cs) + pw = cfg.get(cs, 'password') + pw = pw.encode('utf-8') + constructor(ci,cs,pw) + +#---------- startup ---------- + +def common_startup(): + log_formatter = twisted.logger.formatEventAsClassicLogText + log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter) + twisted.logger.globalLogBeginner.beginLoggingTo( + [ log_observer, crash_on_critical ] + ) + + optparser.add_option('-c', '--config', dest='configfile', + default='/etc/hippotat/config') + (opts, args) = optparser.parse_args() + if len(args): optparser.error('no non-option arguments please') + + re = regexp.compile('#.*') + cfg.read_string(re.sub('', defcfg)) + cfg.read(opts.configfile) + +def common_run(): + log_debug(DBG.INIT, 'entering reactor') + if not _crashing: reactor.run() + print('CRASHED (end)', file=sys.stderr)