chiark / gitweb /
ownsource: tie in (part 1) and cleanup stuff
[hippotat] / hippotatlib / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotatlib.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   OWNSOURCE = NamedConstant()
44   FLOW = NamedConstant()
45   HTTP = NamedConstant()
46   TWISTED = NamedConstant()
47   QUEUE = NamedConstant()
48   HTTP_CTRL = NamedConstant()
49   QUEUE_CTRL = NamedConstant()
50   HTTP_FULL = NamedConstant()
51   CTRL_DUMP = NamedConstant()
52   SLIP_FULL = NamedConstant()
53   DATA_COMPLETE = NamedConstant()
54
55 _hex_codec = codecs.getencoder('hex_codec')
56
57 #---------- logging ----------
58
59 org_stderr = sys.stderr
60
61 log = twisted.logger.Logger()
62
63 debug_set = set()
64 debug_def_detail = DBG.HTTP
65
66 def log_debug(dflag, msg, idof=None, d=None):
67   if dflag not in debug_set: return
68   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
69   if idof is not None:
70     msg = '[%#x] %s' % (id(idof), msg)
71   if d is not None:
72     trunc = ''
73     if not DBG.DATA_COMPLETE in debug_set:
74       if len(d) > 64:
75         d = d[0:64]
76         trunc = '...'
77     d = _hex_codec(d)[0].decode('ascii')
78     msg += ' ' + d + trunc
79   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
80
81 @implementer(twisted.logger.ILogFilterPredicate)
82 class LogNotBoringTwisted:
83   def __call__(self, event):
84     yes = twisted.logger.PredicateResult.yes
85     no  = twisted.logger.PredicateResult.no
86     try:
87       if event.get('log_level') != LogLevel.info:
88         return yes
89       dflag = event.get('dflag')
90       if dflag is False                            : return yes
91       if dflag                         in debug_set: return yes
92       if dflag is None and DBG.TWISTED in debug_set: return yes
93       return no
94     except Exception:
95       print(traceback.format_exc(), file=org_stderr)
96       return yes
97
98 #---------- default config ----------
99
100 defcfg = '''
101 [DEFAULT]
102 max_batch_down = 65536
103 max_queue_time = 10
104 target_requests_outstanding = 3
105 http_timeout = 30
106 http_timeout_grace = 5
107 max_requests_outstanding = 6
108 max_batch_up = 4000
109 http_retry = 5
110 port = 80
111 vroutes = ''
112
113 #[server] or [<client>] overrides
114 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
115
116 # relating to virtual network
117 mtu = 1500
118
119 [SERVER]
120 server = SERVER
121 # addrs = 127.0.0.1 ::1
122 # url
123
124 # relating to virtual network
125 vvnetwork = 172.24.230.192
126 # vnetwork = <prefix>/<len>
127 # vadd  r  = <ipaddr>
128 # vrelay   = <ipaddr>
129
130
131 # [<client-ip4-or-ipv6-address>]
132 # password = <password>    # used by both, must match
133
134 [LIMIT]
135 max_batch_down = 262144
136 max_queue_time = 121
137 http_timeout = 121
138 target_requests_outstanding = 10
139 '''
140
141 # these need to be defined here so that they can be imported by import *
142 cfg = ConfigParser(strict=False)
143 optparser = OptionParser()
144
145 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
146 def mime_translate(s):
147   # SLIP-encoded packets cannot contain ESC ESC.
148   # Swap `-' and ESC.  The result cannot contain `--'
149   return s.translate(_mimetrans)
150
151 class ConfigResults:
152   def __init__(self):
153     pass
154   def __repr__(self):
155     return 'ConfigResults('+repr(self.__dict__)+')'
156
157 def log_discard(packet, iface, saddr, daddr, why):
158   log_debug(DBG.DROP,
159             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
160             d=packet)
161
162 #---------- packet parsing ----------
163
164 def packet_addrs(packet):
165   version = packet[0] >> 4
166   if version == 4:
167     addrlen = 4
168     saddroff = 3*4
169     factory = ipaddress.IPv4Address
170   elif version == 6:
171     addrlen = 16
172     saddroff = 2*4
173     factory = ipaddress.IPv6Address
174   else:
175     raise ValueError('unsupported IP version %d' % version)
176   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
177   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
178   return (saddr, daddr)
179
180 #---------- address handling ----------
181
182 def ipaddr(input):
183   try:
184     r = ipaddress.IPv4Address(input)
185   except AddressValueError:
186     r = ipaddress.IPv6Address(input)
187   return r
188
189 def ipnetwork(input):
190   try:
191     r = ipaddress.IPv4Network(input)
192   except NetworkValueError:
193     r = ipaddress.IPv6Network(input)
194   return r
195
196 #---------- ipif (SLIP) subprocess ----------
197
198 class SlipStreamDecoder():
199   def __init__(self, desc, on_packet):
200     self._buffer = b''
201     self._on_packet = on_packet
202     self._desc = desc
203     self._log('__init__')
204
205   def _log(self, msg, **kwargs):
206     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
207
208   def inputdata(self, data):
209     self._log('inputdata', d=data)
210     data = self._buffer + data
211     self._buffer = b''
212     packets = slip.decode(data, True)
213     self._buffer = packets.pop()
214     for packet in packets:
215       self._maybe_packet(packet)
216     self._log('bufremain', d=self._buffer)
217
218   def _maybe_packet(self, packet):
219     self._log('maybepacket', d=packet)
220     if len(packet):
221       self._on_packet(packet)
222
223   def flush(self):
224     self._log('flush')
225     data = self._buffer
226     self._buffer = b''
227     packets = slip.decode(data)
228     assert(len(packets) == 1)
229     self._maybe_packet(packets[0])
230
231 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
232   def __init__(self, router):
233     self._router = router
234     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
235   def connectionMade(self): pass
236   def outReceived(self, data):
237     self._decoder.inputdata(data)
238   def slip_on_packet(self, packet):
239     (saddr, daddr) = packet_addrs(packet)
240     if saddr.is_link_local or daddr.is_link_local:
241       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
242       return
243     self._router(packet, saddr, daddr)
244   def processEnded(self, status):
245     status.raiseException()
246
247 def start_ipif(command, router):
248   ipif = _IpifProcessProtocol(router)
249   reactor.spawnProcess(ipif,
250                        '/bin/sh',['sh','-xc', command],
251                        childFDs={0:'w', 1:'r', 2:2},
252                        env=None)
253   return ipif
254
255 def queue_inbound(ipif, packet):
256   log_debug(DBG.FLOW, "queue_inbound", d=packet)
257   ipif.transport.write(slip.delimiter)
258   ipif.transport.write(slip.encode(packet))
259   ipif.transport.write(slip.delimiter)
260
261 #---------- packet queue ----------
262
263 class PacketQueue():
264   def __init__(self, desc, max_queue_time):
265     self._desc = desc
266     assert(desc + '')
267     self._max_queue_time = max_queue_time
268     self._pq = collections.deque() # packets
269
270   def _log(self, dflag, msg, **kwargs):
271     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
272
273   def append(self, packet):
274     self._log(DBG.QUEUE, 'append', d=packet)
275     self._pq.append((time.monotonic(), packet))
276
277   def nonempty(self):
278     self._log(DBG.QUEUE, 'nonempty ?')
279     while True:
280       try: (queuetime, packet) = self._pq[0]
281       except IndexError:
282         self._log(DBG.QUEUE, 'nonempty ? empty.')
283         return False
284
285       age = time.monotonic() - queuetime
286       if age > self._max_queue_time:
287         # strip old packets off the front
288         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
289         self._pq.popleft()
290         continue
291
292       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
293       return True
294
295   def process(self, sizequery, moredata, max_batch):
296     # sizequery() should return size of batch so far
297     # moredata(s) should add s to batch
298     self._log(DBG.QUEUE, 'process...')
299     while True:
300       try: (dummy, packet) = self._pq[0]
301       except IndexError:
302         self._log(DBG.QUEUE, 'process... empty')
303         break
304
305       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
306
307       encoded = slip.encode(packet)
308       sofar = sizequery()  
309
310       self._log(DBG.QUEUE_CTRL,
311                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
312                 d=encoded)
313
314       if sofar > 0:
315         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
316           self._log(DBG.QUEUE_CTRL, 'process... overflow')
317           break
318         moredata(slip.delimiter)
319
320       moredata(encoded)
321       self._pq.popleft()
322
323 #---------- error handling ----------
324
325 _crashing = False
326
327 def crash(err):
328   global _crashing
329   _crashing = True
330   print('========== CRASH ==========', err,
331         '===========================', file=sys.stderr)
332   try: reactor.stop()
333   except twisted.internet.error.ReactorNotRunning: pass
334
335 def crash_on_defer(defer):
336   defer.addErrback(lambda err: crash(err))
337
338 def crash_on_critical(event):
339   if event.get('log_level') >= LogLevel.critical:
340     crash(twisted.logger.formatEvent(event))
341
342 #---------- config processing ----------
343
344 def _cfg_process_putatives():
345   servers = { }
346   clients = { }
347   # maps from abstract object to canonical name for cs's
348
349   def putative(cmap, abstract, canoncs):
350     try:
351       current_canoncs = cmap[abstract]
352     except KeyError:
353       pass
354     else:
355       assert(current_canoncs == canoncs)
356     cmap[abstract] = canoncs
357
358   server_pat = r'[-.0-9A-Za-z]+'
359   client_pat = r'[.:0-9a-f]+'
360   server_re = regexp.compile(server_pat)
361   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
362
363   for cs in cfg.sections():
364     if cs == 'LIMIT':
365       # plan A "[LIMIT]"
366       continue
367
368     try:
369       # plan B "[<client>]" part 1
370       ci = ipaddr(cs)
371     except AddressValueError:
372
373       if server_re.fullmatch(cs):
374         # plan C "[<servername>]"
375         putative(servers, cs, cs)
376         continue
377
378       if serverclient_re.fullmatch(cs):
379         # plan D "[<servername> <client>]" part 1
380         (pss,pcs) = cs.split(' ')
381
382         if pcs == 'LIMIT':
383           # plan E "[<servername> LIMIT]"
384           continue
385
386         try:
387           # plan D "[<servername> <client>]" part 2
388           ci = ipaddr(pc)
389         except AddressValueError:
390           # plan F "[<some thing we do not understand>]"
391           # well, we ignore this
392           print('warning: ignoring config section %s' % cs, file=sys.stderr)
393           continue
394
395         else: # no AddressValueError
396           # plan D "[<servername> <client]" part 3
397           putative(clients, ci, pcs)
398           putative(servers, pss, pss)
399           continue
400
401     else: # no AddressValueError
402       # plan B "[<client>" part 2
403       putative(clients, ci, cs)
404       continue
405
406   return (servers, clients)
407
408 def cfg_process_common(c, ss):
409   c.mtu = cfg.getint(ss, 'mtu')
410
411 def cfg_process_saddrs(c, ss):
412   class ServerAddr():
413     def __init__(self, port, addrspec):
414       self.port = port
415       # also self.addr
416       try:
417         self.addr = ipaddress.IPv4Address(addrspec)
418         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
419         self._inurl = b'%s'
420       except AddressValueError:
421         self.addr = ipaddress.IPv6Address(addrspec)
422         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
423         self._inurl = b'[%s]'
424     def make_endpoint(self):
425       return self._endpointfactory(reactor, self.port, self.addr)
426     def url(self):
427       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
428       if self.port != 80: url += b':%d' % self.port
429       url += b'/'
430       return url
431
432   c.port = cfg.getint(ss,'port')
433   c.saddrs = [ ]
434   for addrspec in cfg.get(ss, 'addrs').split():
435     sa = ServerAddr(c.port, addrspec)
436     c.saddrs.append(sa)
437
438 def cfg_process_vnetwork(c, ss):
439   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
440   if c.vnetwork.num_addresses < 3 + 2:
441     raise ValueError('vnetwork needs at least 2^3 addresses')
442
443 def cfg_process_vaddr(c, ss):
444   try:
445     c.vaddr = cfg.get(ss,'vaddr')
446   except NoOptionError:
447     cfg_process_vnetwork(c, ss)
448     c.vaddr = next(c.vnetwork.hosts())
449
450 def cfg_search_section(key,sections):
451   for section in sections:
452     if cfg.has_option(section, key):
453       return section
454   raise NoOptionError(key, repr(sections))
455
456 def cfg_search(getter,key,sections):
457   section = cfg_search_section(key,sections)
458   return getter(section, key)
459
460 def cfg_process_client_limited(cc,ss,sections,key):
461   val = cfg_search(cfg.getint, key, sections)
462   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
463   cc.__dict__[key] = min(val,lim)
464
465 def cfg_process_client_common(cc,ss,cs,ci):
466   # returns sections to search in, iff password is defined, otherwise None
467   cc.ci = ci
468
469   sections = ['%s %s' % (ss,cs),
470               cs,
471               ss,
472               'DEFAULT']
473
474   try: pwsection = cfg_search_section('password', sections)
475   except NoOptionError: return None
476     
477   pw = cfg.get(pwsection, 'password')
478   cc.password = pw.encode('utf-8')
479
480   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
481   cfg_process_client_limited(cc,ss,sections,'http_timeout')
482
483   return sections
484
485 def cfg_process_ipif(c, sections, varmap):
486   for d, s in varmap:
487     try: v = getattr(c, s)
488     except AttributeError: continue
489     setattr(c, d, v)
490
491   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
492
493   section = cfg_search_section('ipif', sections)
494   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
495
496 #---------- startup ----------
497
498 def common_startup(process_cfg):
499   # calls process_cfg(putative_clients, putative_servers)
500
501   # ConfigParser hates #-comments after values
502   trailingcomments_re = regexp.compile(r'#.*')
503   cfg.read_string(trailingcomments_re.sub('', defcfg))
504   need_defcfg = True
505
506   def readconfig(pathname, mandatory=True):
507     def log(m, p=pathname):
508       if not DBG.CONFIG in debug_set: return
509       print('DBG.CONFIG: %s: %s' % (m, pathname))
510
511     try:
512       files = os.listdir(pathname)
513
514     except FileNotFoundError:
515       if mandatory: raise
516       log('skipped')
517       return
518
519     except NotADirectoryError:
520       cfg.read(pathname)
521       log('read file')
522       return
523
524     # is a directory
525     log('directory')
526     re = regexp.compile('[^-A-Za-z0-9_]')
527     for f in os.listdir(cdir):
528       if re.search(f): continue
529       subpath = pathname + '/' + f
530       try:
531         os.stat(subpath)
532       except FileNotFoundError:
533         log('entry skipped', subpath)
534         continue
535       cfg.read(subpath)
536       log('entry read', subpath)
537       
538   def oc_config(od,os, value, op):
539     nonlocal need_defcfg
540     need_defcfg = False
541     readconfig(value)
542
543   def dfs_less_detailed(dl):
544     return [df for df in DBG.iterconstants() if df <= dl]
545
546   def ds_default(od,os,dl,op):
547     global debug_set
548     debug_set = set(dfs_less_detailed(debug_def_detail))
549
550   def ds_select(od,os, spec, op):
551     for it in spec.split(','):
552
553       if it.startswith('-'):
554         mutator = debug_set.discard
555         it = it[1:]
556       else:
557         mutator = debug_set.add
558
559       if it == '+':
560         dfs = DBG.iterconstants()
561
562       else:
563         if it.endswith('+'):
564           mapper = dfs_less_detailed
565           it = it[0:len(it)-1]
566         else:
567           mapper = lambda x: [x]
568
569           try:
570             dfspec = DBG.lookupByName(it)
571           except ValueError:
572             optparser.error('unknown debug flag %s in --debug-select' % it)
573
574         dfs = mapper(dfspec)
575
576       for df in dfs:
577         mutator(df)
578
579   optparser.add_option('-D', '--debug',
580                        nargs=0,
581                        action='callback',
582                        help='enable default debug (to stdout)',
583                        callback= ds_default)
584
585   optparser.add_option('--debug-select',
586                        nargs=1,
587                        type='string',
588                        metavar='[-]DFLAG[+]|[-]+,...',
589                        help=
590 '''enable (`-': disable) each specified DFLAG;
591 `+': do same for all "more interesting" DFLAGSs;
592 just `+': all DFLAGs.
593   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
594                        action='callback',
595                        callback= ds_select)
596
597   optparser.add_option('-c', '--config',
598                        nargs=1,
599                        type='string',
600                        metavar='CONFIGFILE',
601                        dest='configfile',
602                        action='callback',
603                        callback= oc_config)
604
605   (opts, args) = optparser.parse_args()
606   if len(args): optparser.error('no non-option arguments please')
607
608   if need_defcfg:
609     readconfig('/etc/hippotat/config',   False)
610     readconfig('/etc/hippotat/config.d', False)
611
612   try:
613     (pss, pcs) = _cfg_process_putatives()
614     process_cfg(pss, pcs)
615   except (configparser.Error, ValueError):
616     traceback.print_exc(file=sys.stderr)
617     print('\nInvalid configuration, giving up.', file=sys.stderr)
618     sys.exit(12)
619
620   #print(repr(debug_set), file=sys.stderr)
621
622   log_formatter = twisted.logger.formatEventAsClassicLogText
623   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
624   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
625   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
626   stdsomething_obs = twisted.logger.FilteringLogObserver(
627     stderr_obs, [pred], stdout_obs
628   )
629   log_observer = twisted.logger.FilteringLogObserver(
630     stdsomething_obs, [LogNotBoringTwisted()]
631   )
632   #log_observer = stdsomething_obs
633   twisted.logger.globalLogBeginner.beginLoggingTo(
634     [ log_observer, crash_on_critical ]
635     )
636
637 def common_run():
638   log_debug(DBG.INIT, 'entering reactor')
639   if not _crashing: reactor.run()
640   print('CRASHED (end)', file=sys.stderr)