chiark / gitweb /
ownsource: logging etc.
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotatlib.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag is False                            : return yes
90       if dflag                         in debug_set: return yes
91       if dflag is None and DBG.TWISTED in debug_set: return yes
92       return no
93     except Exception:
94       print(traceback.format_exc(), file=org_stderr)
95       return yes
96
97 #---------- default config ----------
98
99 defcfg = '''
100 [DEFAULT]
101 max_batch_down = 65536
102 max_queue_time = 10
103 target_requests_outstanding = 3
104 http_timeout = 30
105 http_timeout_grace = 5
106 max_requests_outstanding = 6
107 max_batch_up = 4000
108 http_retry = 5
109 port = 80
110 vroutes = ''
111
112 #[server] or [<client>] overrides
113 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
114
115 # relating to virtual network
116 mtu = 1500
117
118 [SERVER]
119 server = SERVER
120 # addrs = 127.0.0.1 ::1
121 # url
122
123 # relating to virtual network
124 vvnetwork = 172.24.230.192
125 # vnetwork = <prefix>/<len>
126 # vadd  r  = <ipaddr>
127 # vrelay   = <ipaddr>
128
129
130 # [<client-ip4-or-ipv6-address>]
131 # password = <password>    # used by both, must match
132
133 [LIMIT]
134 max_batch_down = 262144
135 max_queue_time = 121
136 http_timeout = 121
137 target_requests_outstanding = 10
138 '''
139
140 # these need to be defined here so that they can be imported by import *
141 cfg = ConfigParser(strict=False)
142 optparser = OptionParser()
143
144 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
145 def mime_translate(s):
146   # SLIP-encoded packets cannot contain ESC ESC.
147   # Swap `-' and ESC.  The result cannot contain `--'
148   return s.translate(_mimetrans)
149
150 class ConfigResults:
151   def __init__(self):
152     pass
153   def __repr__(self):
154     return 'ConfigResults('+repr(self.__dict__)+')'
155
156 def log_discard(packet, iface, saddr, daddr, why):
157   log_debug(DBG.DROP,
158             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
159             d=packet)
160
161 #---------- packet parsing ----------
162
163 def packet_addrs(packet):
164   version = packet[0] >> 4
165   if version == 4:
166     addrlen = 4
167     saddroff = 3*4
168     factory = ipaddress.IPv4Address
169   elif version == 6:
170     addrlen = 16
171     saddroff = 2*4
172     factory = ipaddress.IPv6Address
173   else:
174     raise ValueError('unsupported IP version %d' % version)
175   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
176   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
177   return (saddr, daddr)
178
179 #---------- address handling ----------
180
181 def ipaddr(input):
182   try:
183     r = ipaddress.IPv4Address(input)
184   except AddressValueError:
185     r = ipaddress.IPv6Address(input)
186   return r
187
188 def ipnetwork(input):
189   try:
190     r = ipaddress.IPv4Network(input)
191   except NetworkValueError:
192     r = ipaddress.IPv6Network(input)
193   return r
194
195 #---------- ipif (SLIP) subprocess ----------
196
197 class SlipStreamDecoder():
198   def __init__(self, desc, on_packet):
199     self._buffer = b''
200     self._on_packet = on_packet
201     self._desc = desc
202     self._log('__init__')
203
204   def _log(self, msg, **kwargs):
205     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
206
207   def inputdata(self, data):
208     self._log('inputdata', d=data)
209     data = self._buffer + data
210     self._buffer = b''
211     packets = slip.decode(data, True)
212     self._buffer = packets.pop()
213     for packet in packets:
214       self._maybe_packet(packet)
215     self._log('bufremain', d=self._buffer)
216
217   def _maybe_packet(self, packet):
218     self._log('maybepacket', d=packet)
219     if len(packet):
220       self._on_packet(packet)
221
222   def flush(self):
223     self._log('flush')
224     data = self._buffer
225     self._buffer = b''
226     packets = slip.decode(data)
227     assert(len(packets) == 1)
228     self._maybe_packet(packets[0])
229
230 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
231   def __init__(self, router):
232     self._router = router
233     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
234   def connectionMade(self): pass
235   def outReceived(self, data):
236     self._decoder.inputdata(data)
237   def slip_on_packet(self, packet):
238     (saddr, daddr) = packet_addrs(packet)
239     if saddr.is_link_local or daddr.is_link_local:
240       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
241       return
242     self._router(packet, saddr, daddr)
243   def processEnded(self, status):
244     status.raiseException()
245
246 def start_ipif(command, router):
247   ipif = _IpifProcessProtocol(router)
248   reactor.spawnProcess(ipif,
249                        '/bin/sh',['sh','-xc', command],
250                        childFDs={0:'w', 1:'r', 2:2},
251                        env=None)
252   return ipif
253
254 def queue_inbound(ipif, packet):
255   log_debug(DBG.FLOW, "queue_inbound", d=packet)
256   ipif.transport.write(slip.delimiter)
257   ipif.transport.write(slip.encode(packet))
258   ipif.transport.write(slip.delimiter)
259
260 #---------- packet queue ----------
261
262 class PacketQueue():
263   def __init__(self, desc, max_queue_time):
264     self._desc = desc
265     assert(desc + '')
266     self._max_queue_time = max_queue_time
267     self._pq = collections.deque() # packets
268
269   def _log(self, dflag, msg, **kwargs):
270     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
271
272   def append(self, packet):
273     self._log(DBG.QUEUE, 'append', d=packet)
274     self._pq.append((time.monotonic(), packet))
275
276   def nonempty(self):
277     self._log(DBG.QUEUE, 'nonempty ?')
278     while True:
279       try: (queuetime, packet) = self._pq[0]
280       except IndexError:
281         self._log(DBG.QUEUE, 'nonempty ? empty.')
282         return False
283
284       age = time.monotonic() - queuetime
285       if age > self._max_queue_time:
286         # strip old packets off the front
287         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
288         self._pq.popleft()
289         continue
290
291       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
292       return True
293
294   def process(self, sizequery, moredata, max_batch):
295     # sizequery() should return size of batch so far
296     # moredata(s) should add s to batch
297     self._log(DBG.QUEUE, 'process...')
298     while True:
299       try: (dummy, packet) = self._pq[0]
300       except IndexError:
301         self._log(DBG.QUEUE, 'process... empty')
302         break
303
304       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
305
306       encoded = slip.encode(packet)
307       sofar = sizequery()  
308
309       self._log(DBG.QUEUE_CTRL,
310                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
311                 d=encoded)
312
313       if sofar > 0:
314         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
315           self._log(DBG.QUEUE_CTRL, 'process... overflow')
316           break
317         moredata(slip.delimiter)
318
319       moredata(encoded)
320       self._pq.popleft()
321
322 #---------- error handling ----------
323
324 _crashing = False
325
326 def crash(err):
327   global _crashing
328   _crashing = True
329   print('========== CRASH ==========', err,
330         '===========================', file=sys.stderr)
331   try: reactor.stop()
332   except twisted.internet.error.ReactorNotRunning: pass
333
334 def crash_on_defer(defer):
335   defer.addErrback(lambda err: crash(err))
336
337 def crash_on_critical(event):
338   if event.get('log_level') >= LogLevel.critical:
339     crash(twisted.logger.formatEvent(event))
340
341 #---------- config processing ----------
342
343 def _cfg_process_putatives():
344   servers = { }
345   clients = { }
346   # maps from abstract object to canonical name for cs's
347
348   def putative(cmap, abstract, canoncs):
349     try:
350       current_canoncs = cmap[abstract]
351     except KeyError:
352       pass
353     else:
354       assert(current_canoncs == canoncs)
355     cmap[abstract] = canoncs
356
357   server_pat = r'[-.0-9A-Za-z]+'
358   client_pat = r'[.:0-9a-f]+'
359   server_re = regexp.compile(server_pat)
360   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
361
362   for cs in cfg.sections():
363     if cs == 'LIMIT':
364       # plan A "[LIMIT]"
365       continue
366
367     try:
368       # plan B "[<client>]" part 1
369       ci = ipaddr(cs)
370     except AddressValueError:
371
372       if server_re.fullmatch(cs):
373         # plan C "[<servername>]"
374         putative(servers, cs, cs)
375         continue
376
377       if serverclient_re.fullmatch(cs):
378         # plan D "[<servername> <client>]" part 1
379         (pss,pcs) = cs.split(' ')
380
381         if pcs == 'LIMIT':
382           # plan E "[<servername> LIMIT]"
383           continue
384
385         try:
386           # plan D "[<servername> <client>]" part 2
387           ci = ipaddr(pc)
388         except AddressValueError:
389           # plan F "[<some thing we do not understand>]"
390           # well, we ignore this
391           print('warning: ignoring config section %s' % cs, file=sys.stderr)
392           continue
393
394         else: # no AddressValueError
395           # plan D "[<servername> <client]" part 3
396           putative(clients, ci, pcs)
397           putative(servers, pss, pss)
398           continue
399
400     else: # no AddressValueError
401       # plan B "[<client>" part 2
402       putative(clients, ci, cs)
403       continue
404
405   return (servers, clients)
406
407 def cfg_process_common(c, ss):
408   c.mtu = cfg.getint(ss, 'mtu')
409
410 def cfg_process_saddrs(c, ss):
411   class ServerAddr():
412     def __init__(self, port, addrspec):
413       self.port = port
414       # also self.addr
415       try:
416         self.addr = ipaddress.IPv4Address(addrspec)
417         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
418         self._inurl = b'%s'
419       except AddressValueError:
420         self.addr = ipaddress.IPv6Address(addrspec)
421         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
422         self._inurl = b'[%s]'
423     def make_endpoint(self):
424       return self._endpointfactory(reactor, self.port, self.addr)
425     def url(self):
426       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
427       if self.port != 80: url += b':%d' % self.port
428       url += b'/'
429       return url
430
431   c.port = cfg.getint(ss,'port')
432   c.saddrs = [ ]
433   for addrspec in cfg.get(ss, 'addrs').split():
434     sa = ServerAddr(c.port, addrspec)
435     c.saddrs.append(sa)
436
437 def cfg_process_vnetwork(c, ss):
438   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
439   if c.vnetwork.num_addresses < 3 + 2:
440     raise ValueError('vnetwork needs at least 2^3 addresses')
441
442 def cfg_process_vaddr(c, ss):
443   try:
444     c.vaddr = cfg.get(ss,'vaddr')
445   except NoOptionError:
446     cfg_process_vnetwork(c, ss)
447     c.vaddr = next(c.vnetwork.hosts())
448
449 def cfg_search_section(key,sections):
450   for section in sections:
451     if cfg.has_option(section, key):
452       return section
453   raise NoOptionError(key, repr(sections))
454
455 def cfg_search(getter,key,sections):
456   section = cfg_search_section(key,sections)
457   return getter(section, key)
458
459 def cfg_process_client_limited(cc,ss,sections,key):
460   val = cfg_search(cfg.getint, key, sections)
461   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
462   cc.__dict__[key] = min(val,lim)
463
464 def cfg_process_client_common(cc,ss,cs,ci):
465   # returns sections to search in, iff password is defined, otherwise None
466   cc.ci = ci
467
468   sections = ['%s %s' % (ss,cs),
469               cs,
470               ss,
471               'DEFAULT']
472
473   try: pwsection = cfg_search_section('password', sections)
474   except NoOptionError: return None
475     
476   pw = cfg.get(pwsection, 'password')
477   cc.password = pw.encode('utf-8')
478
479   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
480   cfg_process_client_limited(cc,ss,sections,'http_timeout')
481
482   return sections
483
484 def cfg_process_ipif(c, sections, varmap):
485   for d, s in varmap:
486     try: v = getattr(c, s)
487     except AttributeError: continue
488     setattr(c, d, v)
489
490   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
491
492   section = cfg_search_section('ipif', sections)
493   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
494
495 #---------- startup ----------
496
497 def common_startup(process_cfg):
498   # calls process_cfg(putative_clients, putative_servers)
499
500   # ConfigParser hates #-comments after values
501   trailingcomments_re = regexp.compile(r'#.*')
502   cfg.read_string(trailingcomments_re.sub('', defcfg))
503   need_defcfg = True
504
505   def readconfig(pathname, mandatory=True):
506     def log(m, p=pathname):
507       if not DBG.CONFIG in debug_set: return
508       print('DBG.CONFIG: %s: %s' % (m, pathname))
509
510     try:
511       files = os.listdir(pathname)
512
513     except FileNotFoundError:
514       if mandatory: raise
515       log('skipped')
516       return
517
518     except NotADirectoryError:
519       cfg.read(pathname)
520       log('read file')
521       return
522
523     # is a directory
524     log('directory')
525     re = regexp.compile('[^-A-Za-z0-9_]')
526     for f in os.listdir(cdir):
527       if re.search(f): continue
528       subpath = pathname + '/' + f
529       try:
530         os.stat(subpath)
531       except FileNotFoundError:
532         log('entry skipped', subpath)
533         continue
534       cfg.read(subpath)
535       log('entry read', subpath)
536       
537   def oc_config(od,os, value, op):
538     nonlocal need_defcfg
539     need_defcfg = False
540     readconfig(value)
541
542   def dfs_less_detailed(dl):
543     return [df for df in DBG.iterconstants() if df <= dl]
544
545   def ds_default(od,os,dl,op):
546     global debug_set
547     debug_set = set(dfs_less_detailed(debug_def_detail))
548
549   def ds_select(od,os, spec, op):
550     for it in spec.split(','):
551
552       if it.startswith('-'):
553         mutator = debug_set.discard
554         it = it[1:]
555       else:
556         mutator = debug_set.add
557
558       if it == '+':
559         dfs = DBG.iterconstants()
560
561       else:
562         if it.endswith('+'):
563           mapper = dfs_less_detailed
564           it = it[0:len(it)-1]
565         else:
566           mapper = lambda x: [x]
567
568           try:
569             dfspec = DBG.lookupByName(it)
570           except ValueError:
571             optparser.error('unknown debug flag %s in --debug-select' % it)
572
573         dfs = mapper(dfspec)
574
575       for df in dfs:
576         mutator(df)
577
578   optparser.add_option('-D', '--debug',
579                        nargs=0,
580                        action='callback',
581                        help='enable default debug (to stdout)',
582                        callback= ds_default)
583
584   optparser.add_option('--debug-select',
585                        nargs=1,
586                        type='string',
587                        metavar='[-]DFLAG[+]|[-]+,...',
588                        help=
589 '''enable (`-': disable) each specified DFLAG;
590 `+': do same for all "more interesting" DFLAGSs;
591 just `+': all DFLAGs.
592   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
593                        action='callback',
594                        callback= ds_select)
595
596   optparser.add_option('-c', '--config',
597                        nargs=1,
598                        type='string',
599                        metavar='CONFIGFILE',
600                        dest='configfile',
601                        action='callback',
602                        callback= oc_config)
603
604   (opts, args) = optparser.parse_args()
605   if len(args): optparser.error('no non-option arguments please')
606
607   if need_defcfg:
608     readconfig('/etc/hippotat/config',   False)
609     readconfig('/etc/hippotat/config.d', False)
610
611   try:
612     (pss, pcs) = _cfg_process_putatives()
613     process_cfg(pss, pcs)
614   except (configparser.Error, ValueError):
615     traceback.print_exc(file=sys.stderr)
616     print('\nInvalid configuration, giving up.', file=sys.stderr)
617     sys.exit(12)
618
619   #print(repr(debug_set), file=sys.stderr)
620
621   log_formatter = twisted.logger.formatEventAsClassicLogText
622   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
623   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
624   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
625   stdsomething_obs = twisted.logger.FilteringLogObserver(
626     stderr_obs, [pred], stdout_obs
627   )
628   log_observer = twisted.logger.FilteringLogObserver(
629     stdsomething_obs, [LogNotBoringTwisted()]
630   )
631   #log_observer = stdsomething_obs
632   twisted.logger.globalLogBeginner.beginLoggingTo(
633     [ log_observer, crash_on_critical ]
634     )
635
636 def common_run():
637   log_debug(DBG.INIT, 'entering reactor')
638   if not _crashing: reactor.run()
639   print('CRASHED (end)', file=sys.stderr)