chiark / gitweb /
less absurd
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
32 class DBG(twisted.python.constants.Names):
33   ROUTE = NamedConstant()
34   DROP = NamedConstant()
35   FLOW = NamedConstant()
36   HTTP = NamedConstant()
37   HTTP_CTRL = NamedConstant()
38   INIT = NamedConstant()
39   QUEUE = NamedConstant()
40   QUEUE_CTRL = NamedConstant()
41   HTTP_FULL = NamedConstant()
42   SLIP_FULL = NamedConstant()
43   CTRL_DUMP = NamedConstant()
44
45 _hex_codec = codecs.getencoder('hex_codec')
46
47 log = twisted.logger.Logger()
48
49 def log_debug(dflag, msg, idof=None, d=None):
50   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
51   if idof is not None:
52     msg = '[%#x] %s' % (id(idof), msg)
53   if d is not None:
54     d = d[0:64]
55     d = _hex_codec(d)[0].decode('ascii')
56     msg += ' ' + d
57   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
58
59 defcfg = '''
60 [DEFAULT]
61 #[<client>] overrides
62 max_batch_down = 65536           # used by server, subject to [limits]
63 max_queue_time = 10              # used by server, subject to [limits]
64 max_request_time = 54            # used by server, subject to [limits]
65 target_requests_outstanding = 3  # must match; subject to [limits] on server
66 max_requests_outstanding = 4     # used by client
67 max_batch_up = 4000              # used by client
68 http_timeout = 30                # used by client
69 http_retry = 5                   # used by client
70
71 #[server] or [<client>] overrides
72 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
73 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
74 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
75 #      from   on client  <client>         [virtual]server   [virtual]routes
76
77 [virtual]
78 mtu = 1500
79 routes = ''
80 # network = <prefix>/<len>  # mandatory for server
81 # server  = <ipaddr>   # used by both, default is computed from `network'
82 # relay   = <ipaddr>   # used by server, default from `network' and `server'
83 #  default server is first host in network
84 #  default relay is first host which is not server
85
86 [server]
87 # addrs = 127.0.0.1 ::1    # mandatory for server
88 port = 80                  # used by server
89 # url              # used by client; default from first `addrs' and `port'
90
91 # [<client-ip4-or-ipv6-address>]
92 # password = <password>    # used by both, must match
93
94 [limits]
95 max_batch_down = 262144           # used by server
96 max_queue_time = 121              # used by server
97 max_request_time = 121            # used by server
98 target_requests_outstanding = 10  # used by server
99 '''
100
101 # these need to be defined here so that they can be imported by import *
102 cfg = ConfigParser()
103 optparser = OptionParser()
104
105 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
106 def mime_translate(s):
107   # SLIP-encoded packets cannot contain ESC ESC.
108   # Swap `-' and ESC.  The result cannot contain `--'
109   return s.translate(_mimetrans)
110
111 class ConfigResults:
112   def __init__(self, d = { }):
113     self.__dict__ = d
114   def __repr__(self):
115     return 'ConfigResults('+repr(self.__dict__)+')'
116
117 c = ConfigResults()
118
119 def log_discard(packet, iface, saddr, daddr, why):
120   log_debug(DBG.DROP,
121             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
122             d=packet)
123
124 #---------- packet parsing ----------
125
126 def packet_addrs(packet):
127   version = packet[0] >> 4
128   if version == 4:
129     addrlen = 4
130     saddroff = 3*4
131     factory = ipaddress.IPv4Address
132   elif version == 6:
133     addrlen = 16
134     saddroff = 2*4
135     factory = ipaddress.IPv6Address
136   else:
137     raise ValueError('unsupported IP version %d' % version)
138   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
139   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
140   return (saddr, daddr)
141
142 #---------- address handling ----------
143
144 def ipaddr(input):
145   try:
146     r = ipaddress.IPv4Address(input)
147   except AddressValueError:
148     r = ipaddress.IPv6Address(input)
149   return r
150
151 def ipnetwork(input):
152   try:
153     r = ipaddress.IPv4Network(input)
154   except NetworkValueError:
155     r = ipaddress.IPv6Network(input)
156   return r
157
158 #---------- ipif (SLIP) subprocess ----------
159
160 class SlipStreamDecoder():
161   def __init__(self, desc, on_packet):
162     self._buffer = b''
163     self._on_packet = on_packet
164     self._desc = desc
165     self._log('__init__')
166
167   def _log(self, msg, **kwargs):
168     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
169
170   def inputdata(self, data):
171     self._log('inputdata', d=data)
172     data = self._buffer + data
173     self._buffer = b''
174     packets = slip.decode(data)
175     self._buffer = packets.pop()
176     for packet in packets:
177       self._maybe_packet(packet)
178     self._log('bufremain', d=self._buffer)
179
180   def _maybe_packet(self, packet):
181     self._log('maybepacket', d=packet)
182     if len(packet):
183       self._on_packet(packet)
184
185   def flush(self):
186     self._log('flush')
187     self._maybe_packet(self._buffer)
188     self._buffer = b''
189
190 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
191   def __init__(self, router):
192     self._router = router
193     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
194   def connectionMade(self): pass
195   def outReceived(self, data):
196     self._decoder.inputdata(data)
197   def slip_on_packet(self, packet):
198     (saddr, daddr) = packet_addrs(packet)
199     if saddr.is_link_local or daddr.is_link_local:
200       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
201       return
202     self._router(packet, saddr, daddr)
203   def processEnded(self, status):
204     status.raiseException()
205
206 def start_ipif(command, router):
207   global ipif
208   ipif = _IpifProcessProtocol(router)
209   reactor.spawnProcess(ipif,
210                        '/bin/sh',['sh','-xc', command],
211                        childFDs={0:'w', 1:'r', 2:2},
212                        env=None)
213
214 def queue_inbound(packet):
215   log_debug(DBG.FLOW, "queue_inbound", d=packet)
216   ipif.transport.write(slip.delimiter)
217   ipif.transport.write(slip.encode(packet))
218   ipif.transport.write(slip.delimiter)
219
220 #---------- packet queue ----------
221
222 class PacketQueue():
223   def __init__(self, desc, max_queue_time):
224     self._desc = desc
225     assert(desc + '')
226     self._max_queue_time = max_queue_time
227     self._pq = collections.deque() # packets
228
229   def _log(self, dflag, msg, **kwargs):
230     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
231
232   def append(self, packet):
233     self._log(DBG.QUEUE, 'append', d=packet)
234     self._pq.append((time.monotonic(), packet))
235
236   def nonempty(self):
237     self._log(DBG.QUEUE, 'nonempty ?')
238     while True:
239       try: (queuetime, packet) = self._pq[0]
240       except IndexError:
241         self._log(DBG.QUEUE, 'nonempty ? empty.')
242         return False
243
244       age = time.monotonic() - queuetime
245       if age > self._max_queue_time:
246         # strip old packets off the front
247         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
248         self._pq.popleft()
249         continue
250
251       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
252       return True
253
254   def process(self, sizequery, moredata, max_batch):
255     # sizequery() should return size of batch so far
256     # moredata(s) should add s to batch
257     self._log(DBG.QUEUE, 'process...')
258     while True:
259       try: (dummy, packet) = self._pq[0]
260       except IndexError:
261         self._log(DBG.QUEUE, 'process... empty')
262         break
263
264       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
265
266       encoded = slip.encode(packet)
267       sofar = sizequery()  
268
269       self._log(DBG.QUEUE_CTRL,
270                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
271                 d=encoded)
272
273       if sofar > 0:
274         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
275           self._log(DBG.QUEUE_CTRL, 'process... overflow')
276           break
277         moredata(slip.delimiter)
278
279       moredata(encoded)
280       self._pq.popleft()
281
282 #---------- error handling ----------
283
284 _crashing = False
285
286 def crash(err):
287   global _crashing
288   _crashing = True
289   print('========== CRASH ==========', err,
290         '===========================', file=sys.stderr)
291   try: reactor.stop()
292   except twisted.internet.error.ReactorNotRunning: pass
293
294 def crash_on_defer(defer):
295   defer.addErrback(lambda err: crash(err))
296
297 def crash_on_critical(event):
298   if event.get('log_level') >= LogLevel.critical:
299     crash(twisted.logger.formatEvent(event))
300
301 #---------- config processing ----------
302
303 def process_cfg_common_always():
304   global mtu
305   c.mtu = cfg.get('virtual','mtu')
306
307 def process_cfg_ipif(section, varmap):
308   for d, s in varmap:
309     try: v = getattr(c, s)
310     except AttributeError: continue
311     setattr(c, d, v)
312
313   #print(repr(c))
314
315   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
316
317 def process_cfg_network():
318   c.network = ipnetwork(cfg.get('virtual','network'))
319   if c.network.num_addresses < 3 + 2:
320     raise ValueError('network needs at least 2^3 addresses')
321
322 def process_cfg_server():
323   try:
324     c.server = cfg.get('virtual','server')
325   except NoOptionError:
326     process_cfg_network()
327     c.server = next(c.network.hosts())
328
329 class ServerAddr():
330   def __init__(self, port, addrspec):
331     self.port = port
332     # also self.addr
333     try:
334       self.addr = ipaddress.IPv4Address(addrspec)
335       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
336       self._inurl = b'%s'
337     except AddressValueError:
338       self.addr = ipaddress.IPv6Address(addrspec)
339       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
340       self._inurl = b'[%s]'
341   def make_endpoint(self):
342     return self._endpointfactory(reactor, self.port, self.addr)
343   def url(self):
344     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
345     if self.port != 80: url += b':%d' % self.port
346     url += b'/'
347     return url
348     
349 def process_cfg_saddrs():
350   try: port = cfg.getint('server','port')
351   except NoOptionError: port = 80
352
353   c.saddrs = [ ]
354   for addrspec in cfg.get('server','addrs').split():
355     sa = ServerAddr(port, addrspec)
356     c.saddrs.append(sa)
357
358 def process_cfg_clients(constructor):
359   c.clients = [ ]
360   for cs in cfg.sections():
361     if not (':' in cs or '.' in cs): continue
362     ci = ipaddr(cs)
363     pw = cfg.get(cs, 'password')
364     pw = pw.encode('utf-8')
365     constructor(ci,cs,pw)
366
367 #---------- startup ----------
368
369 def common_startup():
370   log_formatter = twisted.logger.formatEventAsClassicLogText
371   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
372   twisted.logger.globalLogBeginner.beginLoggingTo(
373     [ log_observer, crash_on_critical ]
374     )
375
376   optparser.add_option('-c', '--config', dest='configfile',
377                        default='/etc/hippotat/config')
378   (opts, args) = optparser.parse_args()
379   if len(args): optparser.error('no non-option arguments please')
380
381   re = regexp.compile('#.*')
382   cfg.read_string(re.sub('', defcfg))
383   cfg.read(opts.configfile)
384
385 def common_run():
386   log_debug(DBG.INIT, 'entering reactor')
387   if not _crashing: reactor.run()
388   print('CRASHED (end)', file=sys.stderr)