chiark / gitweb /
6b87da9f63919f64979e3bf8680bc361a99abdc3
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 from configparser import ConfigParser
24 from configparser import NoOptionError
25
26 from functools import partial
27
28 import collections
29 import time
30 import codecs
31 import traceback
32
33 import re as regexp
34
35 import hippotat.slip as slip
36
37 class DBG(twisted.python.constants.Names):
38   INIT = NamedConstant()
39   CONFIG = NamedConstant()
40   ROUTE = NamedConstant()
41   DROP = NamedConstant()
42   FLOW = NamedConstant()
43   HTTP = NamedConstant()
44   TWISTED = NamedConstant()
45   QUEUE = NamedConstant()
46   HTTP_CTRL = NamedConstant()
47   QUEUE_CTRL = NamedConstant()
48   HTTP_FULL = NamedConstant()
49   CTRL_DUMP = NamedConstant()
50   SLIP_FULL = NamedConstant()
51   DATA_COMPLETE = NamedConstant()
52
53 _hex_codec = codecs.getencoder('hex_codec')
54
55 #---------- logging ----------
56
57 org_stderr = sys.stderr
58
59 log = twisted.logger.Logger()
60
61 debug_set = set()
62 debug_def_detail = DBG.HTTP
63
64 def log_debug(dflag, msg, idof=None, d=None):
65   if dflag not in debug_set: return
66   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
67   if idof is not None:
68     msg = '[%#x] %s' % (id(idof), msg)
69   if d is not None:
70     trunc = ''
71     if not DBG.DATA_COMPLETE in debug_set:
72       if len(d) > 64:
73         d = d[0:64]
74         trunc = '...'
75     d = _hex_codec(d)[0].decode('ascii')
76     msg += ' ' + d + trunc
77   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
78
79 @implementer(twisted.logger.ILogFilterPredicate)
80 class LogNotBoringTwisted:
81   def __call__(self, event):
82     yes = twisted.logger.PredicateResult.yes
83     no  = twisted.logger.PredicateResult.no
84     try:
85       if event.get('log_level') != LogLevel.info:
86         return yes
87       dflag = event.get('dflag')
88       if dflag                         in debug_set: return yes
89       if dflag is None and DBG.TWISTED in debug_set: return yes
90       return no
91     except Exception:
92       print(traceback.format_exc(), file=org_stderr)
93       return yes
94
95 #---------- default config ----------
96
97 defcfg = '''
98 [DEFAULT]
99 #[<client>] overrides
100 max_batch_down = 65536           # used by server, subject to [limits]
101 max_queue_time = 10              # used by server, subject to [limits]
102 target_requests_outstanding = 3  # must match; subject to [limits] on server
103 http_timeout = 30                # used by both } must be
104 http_timeout_grace = 5           # used by both }  compatible
105 max_requests_outstanding = 4     # used by client
106 max_batch_up = 4000              # used by client
107 http_retry = 5                   # used by client
108
109 #[server] or [<client>] overrides
110 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
111 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
112 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
113 #      from   on client  <client>         [virtual]server   [virtual]routes
114
115 [virtual]
116 mtu = 1500
117 routes = ''
118 # network = <prefix>/<len>  # mandatory for server
119 # server  = <ipaddr>   # used by both, default is computed from `network'
120 # relay   = <ipaddr>   # used by server, default from `network' and `server'
121 #  default server is first host in network
122 #  default relay is first host which is not server
123
124 [server]
125 # addrs = 127.0.0.1 ::1    # mandatory for server
126 port = 80                  # used by server
127 # url              # used by client; default from first `addrs' and `port'
128
129 # [<client-ip4-or-ipv6-address>]
130 # password = <password>    # used by both, must match
131
132 [limits]
133 max_batch_down = 262144           # used by server
134 max_queue_time = 121              # used by server
135 http_timeout = 121                # used by server
136 target_requests_outstanding = 10  # used by server
137 '''
138
139 # these need to be defined here so that they can be imported by import *
140 cfg = ConfigParser(strict=False)
141 optparser = OptionParser()
142
143 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
144 def mime_translate(s):
145   # SLIP-encoded packets cannot contain ESC ESC.
146   # Swap `-' and ESC.  The result cannot contain `--'
147   return s.translate(_mimetrans)
148
149 class ConfigResults:
150   def __init__(self, d = { }):
151     self.__dict__ = d
152   def __repr__(self):
153     return 'ConfigResults('+repr(self.__dict__)+')'
154
155 c = ConfigResults()
156
157 def log_discard(packet, iface, saddr, daddr, why):
158   log_debug(DBG.DROP,
159             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
160             d=packet)
161
162 #---------- packet parsing ----------
163
164 def packet_addrs(packet):
165   version = packet[0] >> 4
166   if version == 4:
167     addrlen = 4
168     saddroff = 3*4
169     factory = ipaddress.IPv4Address
170   elif version == 6:
171     addrlen = 16
172     saddroff = 2*4
173     factory = ipaddress.IPv6Address
174   else:
175     raise ValueError('unsupported IP version %d' % version)
176   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
177   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
178   return (saddr, daddr)
179
180 #---------- address handling ----------
181
182 def ipaddr(input):
183   try:
184     r = ipaddress.IPv4Address(input)
185   except AddressValueError:
186     r = ipaddress.IPv6Address(input)
187   return r
188
189 def ipnetwork(input):
190   try:
191     r = ipaddress.IPv4Network(input)
192   except NetworkValueError:
193     r = ipaddress.IPv6Network(input)
194   return r
195
196 #---------- ipif (SLIP) subprocess ----------
197
198 class SlipStreamDecoder():
199   def __init__(self, desc, on_packet):
200     self._buffer = b''
201     self._on_packet = on_packet
202     self._desc = desc
203     self._log('__init__')
204
205   def _log(self, msg, **kwargs):
206     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
207
208   def inputdata(self, data):
209     self._log('inputdata', d=data)
210     packets = slip.decode(data)
211     packets[0] = self._buffer + packets[0]
212     self._buffer = packets.pop()
213     for packet in packets:
214       self._maybe_packet(packet)
215     self._log('bufremain', d=self._buffer)
216
217   def _maybe_packet(self, packet):
218     self._log('maybepacket', d=packet)
219     if len(packet):
220       self._on_packet(packet)
221
222   def flush(self):
223     self._log('flush')
224     self._maybe_packet(self._buffer)
225     self._buffer = b''
226
227 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
228   def __init__(self, router):
229     self._router = router
230     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
231   def connectionMade(self): pass
232   def outReceived(self, data):
233     self._decoder.inputdata(data)
234   def slip_on_packet(self, packet):
235     (saddr, daddr) = packet_addrs(packet)
236     if saddr.is_link_local or daddr.is_link_local:
237       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
238       return
239     self._router(packet, saddr, daddr)
240   def processEnded(self, status):
241     status.raiseException()
242
243 def start_ipif(command, router):
244   global ipif
245   ipif = _IpifProcessProtocol(router)
246   reactor.spawnProcess(ipif,
247                        '/bin/sh',['sh','-xc', command],
248                        childFDs={0:'w', 1:'r', 2:2},
249                        env=None)
250
251 def queue_inbound(packet):
252   log_debug(DBG.FLOW, "queue_inbound", d=packet)
253   ipif.transport.write(slip.delimiter)
254   ipif.transport.write(slip.encode(packet))
255   ipif.transport.write(slip.delimiter)
256
257 #---------- packet queue ----------
258
259 class PacketQueue():
260   def __init__(self, desc, max_queue_time):
261     self._desc = desc
262     assert(desc + '')
263     self._max_queue_time = max_queue_time
264     self._pq = collections.deque() # packets
265
266   def _log(self, dflag, msg, **kwargs):
267     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
268
269   def append(self, packet):
270     self._log(DBG.QUEUE, 'append', d=packet)
271     self._pq.append((time.monotonic(), packet))
272
273   def nonempty(self):
274     self._log(DBG.QUEUE, 'nonempty ?')
275     while True:
276       try: (queuetime, packet) = self._pq[0]
277       except IndexError:
278         self._log(DBG.QUEUE, 'nonempty ? empty.')
279         return False
280
281       age = time.monotonic() - queuetime
282       if age > self._max_queue_time:
283         # strip old packets off the front
284         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
285         self._pq.popleft()
286         continue
287
288       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
289       return True
290
291   def process(self, sizequery, moredata, max_batch):
292     # sizequery() should return size of batch so far
293     # moredata(s) should add s to batch
294     self._log(DBG.QUEUE, 'process...')
295     while True:
296       try: (dummy, packet) = self._pq[0]
297       except IndexError:
298         self._log(DBG.QUEUE, 'process... empty')
299         break
300
301       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
302
303       encoded = slip.encode(packet)
304       sofar = sizequery()  
305
306       self._log(DBG.QUEUE_CTRL,
307                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
308                 d=encoded)
309
310       if sofar > 0:
311         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
312           self._log(DBG.QUEUE_CTRL, 'process... overflow')
313           break
314         moredata(slip.delimiter)
315
316       moredata(encoded)
317       self._pq.popleft()
318
319 #---------- error handling ----------
320
321 _crashing = False
322
323 def crash(err):
324   global _crashing
325   _crashing = True
326   print('========== CRASH ==========', err,
327         '===========================', file=sys.stderr)
328   try: reactor.stop()
329   except twisted.internet.error.ReactorNotRunning: pass
330
331 def crash_on_defer(defer):
332   defer.addErrback(lambda err: crash(err))
333
334 def crash_on_critical(event):
335   if event.get('log_level') >= LogLevel.critical:
336     crash(twisted.logger.formatEvent(event))
337
338 #---------- config processing ----------
339
340 def process_cfg_common_always():
341   global mtu
342   c.mtu = cfg.get('virtual','mtu')
343
344 def process_cfg_ipif(section, varmap):
345   for d, s in varmap:
346     try: v = getattr(c, s)
347     except AttributeError: continue
348     setattr(c, d, v)
349
350   #print(repr(c))
351
352   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
353
354 def process_cfg_network():
355   c.network = ipnetwork(cfg.get('virtual','network'))
356   if c.network.num_addresses < 3 + 2:
357     raise ValueError('network needs at least 2^3 addresses')
358
359 def process_cfg_server():
360   try:
361     c.server = cfg.get('virtual','server')
362   except NoOptionError:
363     process_cfg_network()
364     c.server = next(c.network.hosts())
365
366 class ServerAddr():
367   def __init__(self, port, addrspec):
368     self.port = port
369     # also self.addr
370     try:
371       self.addr = ipaddress.IPv4Address(addrspec)
372       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
373       self._inurl = b'%s'
374     except AddressValueError:
375       self.addr = ipaddress.IPv6Address(addrspec)
376       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
377       self._inurl = b'[%s]'
378   def make_endpoint(self):
379     return self._endpointfactory(reactor, self.port, self.addr)
380   def url(self):
381     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
382     if self.port != 80: url += b':%d' % self.port
383     url += b'/'
384     return url
385     
386 def process_cfg_saddrs():
387   try: port = cfg.getint('server','port')
388   except NoOptionError: port = 80
389
390   c.saddrs = [ ]
391   for addrspec in cfg.get('server','addrs').split():
392     sa = ServerAddr(port, addrspec)
393     c.saddrs.append(sa)
394
395 def process_cfg_clients(constructor):
396   c.clients = [ ]
397   for cs in cfg.sections():
398     if not (':' in cs or '.' in cs): continue
399     ci = ipaddr(cs)
400     pw = cfg.get(cs, 'password')
401     pw = pw.encode('utf-8')
402     constructor(ci,cs,pw)
403
404 #---------- startup ----------
405
406 def common_startup():
407   re = regexp.compile('#.*')
408   cfg.read_string(re.sub('', defcfg))
409   need_defcfg = True
410
411   def readconfig(pathname, mandatory=True):
412     def log(m, p=pathname):
413       if not DBG.CONFIG in debug_set: return
414       print('DBG.CONFIG: %s: %s' % (m, pathname))
415
416     try:
417       files = os.listdir(pathname)
418
419     except FileNotFoundError:
420       if mandatory: raise
421       log('skipped')
422       return
423
424     except NotADirectoryError:
425       cfg.read(pathname)
426       log('read file')
427       return
428
429     # is a directory
430     log('directory')
431     re = regexp.compile('[^-A-Za-z0-9_]')
432     for f in os.listdir(cdir):
433       if re.search(f): continue
434       subpath = pathname + '/' + f
435       try:
436         os.stat(subpath)
437       except FileNotFoundError:
438         log('entry skipped', subpath)
439         continue
440       cfg.read(subpath)
441       log('entry read', subpath)
442       
443   def oc_config(od,os, value, op):
444     nonlocal need_defcfg
445     need_defcfg = False
446     readconfig(value)
447
448   def dfs_less_detailed(dl):
449     return [df for df in DBG.iterconstants() if df <= dl]
450
451   def ds_default(od,os,dl,op):
452     global debug_set
453     debug_set = set(dfs_less_detailed(debug_def_detail))
454
455   def ds_select(od,os, spec, op):
456     for it in spec.split(','):
457
458       if it.startswith('-'):
459         mutator = debug_set.discard
460         it = it[1:]
461       else:
462         mutator = debug_set.add
463
464       if it == '+':
465         dfs = DBG.iterconstants()
466
467       else:
468         if it.endswith('+'):
469           mapper = dfs_less_detailed
470           it = it[0:len(it)-1]
471         else:
472           mapper = lambda x: [x]
473
474           try:
475             dfspec = DBG.lookupByName(it)
476           except ValueError:
477             optparser.error('unknown debug flag %s in --debug-select' % it)
478
479         dfs = mapper(dfspec)
480
481       for df in dfs:
482         mutator(df)
483
484   optparser.add_option('-D', '--debug',
485                        nargs=0,
486                        action='callback',
487                        help='enable default debug (to stdout)',
488                        callback= ds_default)
489
490   optparser.add_option('--debug-select',
491                        nargs=1,
492                        type='string',
493                        metavar='[-]DFLAG[+]|[-]+,...',
494                        help=
495 '''enable (`-': disable) each specified DFLAG;
496 `+': do same for all "more interesting" DFLAGSs;
497 just `+': all DFLAGs.
498   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
499                        action='callback',
500                        callback= ds_select)
501
502   optparser.add_option('-c', '--config',
503                        nargs=1,
504                        type='string',
505                        metavar='CONFIGFILE',
506                        dest='configfile',
507                        action='callback',
508                        callback= oc_config)
509
510   (opts, args) = optparser.parse_args()
511   if len(args): optparser.error('no non-option arguments please')
512
513   if need_defcfg:
514     readconfig('/etc/hippotat/config',   False)
515     readconfig('/etc/hippotat/config.d', False)
516
517   #print(repr(debug_set), file=sys.stderr)
518
519   log_formatter = twisted.logger.formatEventAsClassicLogText
520   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
521   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
522   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
523   stdsomething_obs = twisted.logger.FilteringLogObserver(
524     stderr_obs, [pred], stdout_obs
525   )
526   log_observer = twisted.logger.FilteringLogObserver(
527     stdsomething_obs, [LogNotBoringTwisted()]
528   )
529   #log_observer = stdsomething_obs
530   twisted.logger.globalLogBeginner.beginLoggingTo(
531     [ log_observer, crash_on_critical ]
532     )
533
534 def common_run():
535   log_debug(DBG.INIT, 'entering reactor')
536   if not _crashing: reactor.run()
537   print('CRASHED (end)', file=sys.stderr)