chiark / gitweb /
9f72b48abcee4f225cb0917dbf5dcece16450a0e
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag is False                            : return yes
90       if dflag                         in debug_set: return yes
91       if dflag is None and DBG.TWISTED in debug_set: return yes
92       return no
93     except Exception:
94       print(traceback.format_exc(), file=org_stderr)
95       return yes
96
97 #---------- default config ----------
98
99 defcfg = '''
100 [DEFAULT]
101 max_batch_down = 65536
102 max_queue_time = 10
103 target_requests_outstanding = 3
104 http_timeout = 30
105 http_timeout_grace = 5
106 max_requests_outstanding = 6
107 max_batch_up = 4000
108 http_retry = 5
109 port = 80
110 vroutes = ''
111
112 #[server] or [<client>] overrides
113 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
114
115 # relating to virtual network
116 mtu = 1500
117
118 [SERVER]
119 server = SERVER
120 # addrs = 127.0.0.1 ::1
121 # url
122
123 # relating to virtual network
124 vvnetwork = 172.24.230.192
125 # vnetwork = <prefix>/<len>
126 # vadd  r  = <ipaddr>
127 # vrelay   = <ipaddr>
128
129
130 # [<client-ip4-or-ipv6-address>]
131 # password = <password>    # used by both, must match
132
133 [LIMIT]
134 max_batch_down = 262144
135 max_queue_time = 121
136 http_timeout = 121
137 target_requests_outstanding = 10
138 '''
139
140 # these need to be defined here so that they can be imported by import *
141 cfg = ConfigParser(strict=False)
142 optparser = OptionParser()
143
144 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
145 def mime_translate(s):
146   # SLIP-encoded packets cannot contain ESC ESC.
147   # Swap `-' and ESC.  The result cannot contain `--'
148   return s.translate(_mimetrans)
149
150 class ConfigResults:
151   def __init__(self):
152     pass
153   def __repr__(self):
154     return 'ConfigResults('+repr(self.__dict__)+')'
155
156 def log_discard(packet, iface, saddr, daddr, why):
157   log_debug(DBG.DROP,
158             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
159             d=packet)
160
161 #---------- packet parsing ----------
162
163 def packet_addrs(packet):
164   version = packet[0] >> 4
165   if version == 4:
166     addrlen = 4
167     saddroff = 3*4
168     factory = ipaddress.IPv4Address
169   elif version == 6:
170     addrlen = 16
171     saddroff = 2*4
172     factory = ipaddress.IPv6Address
173   else:
174     raise ValueError('unsupported IP version %d' % version)
175   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
176   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
177   return (saddr, daddr)
178
179 #---------- address handling ----------
180
181 def ipaddr(input):
182   try:
183     r = ipaddress.IPv4Address(input)
184   except AddressValueError:
185     r = ipaddress.IPv6Address(input)
186   return r
187
188 def ipnetwork(input):
189   try:
190     r = ipaddress.IPv4Network(input)
191   except NetworkValueError:
192     r = ipaddress.IPv6Network(input)
193   return r
194
195 #---------- ipif (SLIP) subprocess ----------
196
197 class SlipStreamDecoder():
198   def __init__(self, desc, on_packet):
199     self._buffer = b''
200     self._on_packet = on_packet
201     self._desc = desc
202     self._log('__init__')
203
204   def _log(self, msg, **kwargs):
205     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
206
207   def inputdata(self, data):
208     self._log('inputdata', d=data)
209     packets = slip.decode(data)
210     packets[0] = self._buffer + packets[0]
211     self._buffer = packets.pop()
212     for packet in packets:
213       self._maybe_packet(packet)
214     self._log('bufremain', d=self._buffer)
215
216   def _maybe_packet(self, packet):
217     self._log('maybepacket', d=packet)
218     if len(packet):
219       self._on_packet(packet)
220
221   def flush(self):
222     self._log('flush')
223     self._maybe_packet(self._buffer)
224     self._buffer = b''
225
226 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
227   def __init__(self, router):
228     self._router = router
229     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
230   def connectionMade(self): pass
231   def outReceived(self, data):
232     self._decoder.inputdata(data)
233   def slip_on_packet(self, packet):
234     (saddr, daddr) = packet_addrs(packet)
235     if saddr.is_link_local or daddr.is_link_local:
236       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
237       return
238     self._router(packet, saddr, daddr)
239   def processEnded(self, status):
240     status.raiseException()
241
242 def start_ipif(command, router):
243   global ipif
244   ipif = _IpifProcessProtocol(router)
245   reactor.spawnProcess(ipif,
246                        '/bin/sh',['sh','-xc', command],
247                        childFDs={0:'w', 1:'r', 2:2},
248                        env=None)
249
250 def queue_inbound(packet):
251   log_debug(DBG.FLOW, "queue_inbound", d=packet)
252   ipif.transport.write(slip.delimiter)
253   ipif.transport.write(slip.encode(packet))
254   ipif.transport.write(slip.delimiter)
255
256 #---------- packet queue ----------
257
258 class PacketQueue():
259   def __init__(self, desc, max_queue_time):
260     self._desc = desc
261     assert(desc + '')
262     self._max_queue_time = max_queue_time
263     self._pq = collections.deque() # packets
264
265   def _log(self, dflag, msg, **kwargs):
266     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
267
268   def append(self, packet):
269     self._log(DBG.QUEUE, 'append', d=packet)
270     self._pq.append((time.monotonic(), packet))
271
272   def nonempty(self):
273     self._log(DBG.QUEUE, 'nonempty ?')
274     while True:
275       try: (queuetime, packet) = self._pq[0]
276       except IndexError:
277         self._log(DBG.QUEUE, 'nonempty ? empty.')
278         return False
279
280       age = time.monotonic() - queuetime
281       if age > self._max_queue_time:
282         # strip old packets off the front
283         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
284         self._pq.popleft()
285         continue
286
287       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
288       return True
289
290   def process(self, sizequery, moredata, max_batch):
291     # sizequery() should return size of batch so far
292     # moredata(s) should add s to batch
293     self._log(DBG.QUEUE, 'process...')
294     while True:
295       try: (dummy, packet) = self._pq[0]
296       except IndexError:
297         self._log(DBG.QUEUE, 'process... empty')
298         break
299
300       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
301
302       encoded = slip.encode(packet)
303       sofar = sizequery()  
304
305       self._log(DBG.QUEUE_CTRL,
306                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
307                 d=encoded)
308
309       if sofar > 0:
310         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
311           self._log(DBG.QUEUE_CTRL, 'process... overflow')
312           break
313         moredata(slip.delimiter)
314
315       moredata(encoded)
316       self._pq.popleft()
317
318 #---------- error handling ----------
319
320 _crashing = False
321
322 def crash(err):
323   global _crashing
324   _crashing = True
325   print('========== CRASH ==========', err,
326         '===========================', file=sys.stderr)
327   try: reactor.stop()
328   except twisted.internet.error.ReactorNotRunning: pass
329
330 def crash_on_defer(defer):
331   defer.addErrback(lambda err: crash(err))
332
333 def crash_on_critical(event):
334   if event.get('log_level') >= LogLevel.critical:
335     crash(twisted.logger.formatEvent(event))
336
337 #---------- config processing ----------
338
339 def _cfg_process_putatives():
340   servers = { }
341   clients = { }
342   # maps from abstract object to canonical name for cs's
343
344   def putative(cmap, abstract, canoncs):
345     try:
346       current_canoncs = cmap[abstract]
347     except KeyError:
348       pass
349     else:
350       assert(current_canoncs == canoncs)
351     cmap[abstract] = canoncs
352
353   server_pat = r'[-.0-9A-Za-z]+'
354   client_pat = r'[.:0-9a-f]+'
355   server_re = regexp.compile(server_pat)
356   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
357
358   for cs in cfg.sections():
359     if cs == 'LIMIT':
360       # plan A "[LIMIT]"
361       continue
362
363     try:
364       # plan B "[<client>]" part 1
365       ci = ipaddr(cs)
366     except AddressValueError:
367
368       if server_re.fullmatch(cs):
369         # plan C "[<servername>]"
370         putative(servers, cs, cs)
371         continue
372
373       if serverclient_re.fullmatch(cs):
374         # plan D "[<servername> <client>]" part 1
375         (pss,pcs) = cs.split(' ')
376
377         if pcs == 'LIMIT':
378           # plan E "[<servername> LIMIT]"
379           continue
380
381         try:
382           # plan D "[<servername> <client>]" part 2
383           ci = ipaddr(pc)
384         except AddressValueError:
385           # plan F "[<some thing we do not understand>]"
386           # well, we ignore this
387           print('warning: ignoring config section %s' % cs, file=sys.stderr)
388           continue
389
390         else: # no AddressValueError
391           # plan D "[<servername> <client]" part 3
392           putative(clients, ci, pcs)
393           putative(servers, pss, pss)
394           continue
395
396     else: # no AddressValueError
397       # plan B "[<client>" part 2
398       putative(clients, ci, cs)
399       continue
400
401   return (servers, clients)
402
403 def cfg_process_common(c, ss):
404   c.mtu = cfg.getint(ss, 'mtu')
405
406 def cfg_process_saddrs(c, ss):
407   class ServerAddr():
408     def __init__(self, port, addrspec):
409       self.port = port
410       # also self.addr
411       try:
412         self.addr = ipaddress.IPv4Address(addrspec)
413         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
414         self._inurl = b'%s'
415       except AddressValueError:
416         self.addr = ipaddress.IPv6Address(addrspec)
417         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
418         self._inurl = b'[%s]'
419     def make_endpoint(self):
420       return self._endpointfactory(reactor, self.port, self.addr)
421     def url(self):
422       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
423       if self.port != 80: url += b':%d' % self.port
424       url += b'/'
425       return url
426
427   c.port = cfg.getint(ss,'port')
428   c.saddrs = [ ]
429   for addrspec in cfg.get(ss, 'addrs').split():
430     sa = ServerAddr(c.port, addrspec)
431     c.saddrs.append(sa)
432
433 def cfg_process_vnetwork(c, ss):
434   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
435   if c.vnetwork.num_addresses < 3 + 2:
436     raise ValueError('vnetwork needs at least 2^3 addresses')
437
438 def cfg_process_vaddr(c, ss):
439   try:
440     c.vaddr = cfg.get(ss,'vaddr')
441   except NoOptionError:
442     cfg_process_vnetwork(c, ss)
443     c.vaddr = next(c.vnetwork.hosts())
444
445 def cfg_search_section(key,sections):
446   for section in sections:
447     if cfg.has_option(section, key):
448       return section
449   raise NoOptionError(key, repr(sections))
450
451 def cfg_search(getter,key,sections):
452   section = cfg_search_section(key,sections)
453   return getter(section, key)
454
455 def cfg_process_client_limited(cc,ss,sections,key):
456   val = cfg_search(cfg.getint, key, sections)
457   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
458   cc.__dict__[key] = min(val,lim)
459
460 def cfg_process_client_common(cc,ss,cs,ci):
461   # returns sections to search in, iff password is defined, otherwise None
462   cc.ci = ci
463
464   sections = ['%s %s' % (ss,cs),
465               cs,
466               ss,
467               'DEFAULT']
468
469   try: pwsection = cfg_search_section('password', sections)
470   except NoOptionError: return None
471     
472   pw = cfg.get(pwsection, 'password')
473   cc.password = pw.encode('utf-8')
474
475   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
476   cfg_process_client_limited(cc,ss,sections,'http_timeout')
477
478   return sections
479
480 def cfg_process_ipif(c, sections, varmap):
481   for d, s in varmap:
482     try: v = getattr(c, s)
483     except AttributeError: continue
484     setattr(c, d, v)
485
486   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
487
488   section = cfg_search_section('ipif', sections)
489   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
490
491 #---------- startup ----------
492
493 def common_startup(process_cfg):
494   # calls process_cfg(putative_clients, putative_servers)
495
496   # ConfigParser hates #-comments after values
497   trailingcomments_re = regexp.compile(r'#.*')
498   cfg.read_string(trailingcomments_re.sub('', defcfg))
499   need_defcfg = True
500
501   def readconfig(pathname, mandatory=True):
502     def log(m, p=pathname):
503       if not DBG.CONFIG in debug_set: return
504       print('DBG.CONFIG: %s: %s' % (m, pathname))
505
506     try:
507       files = os.listdir(pathname)
508
509     except FileNotFoundError:
510       if mandatory: raise
511       log('skipped')
512       return
513
514     except NotADirectoryError:
515       cfg.read(pathname)
516       log('read file')
517       return
518
519     # is a directory
520     log('directory')
521     re = regexp.compile('[^-A-Za-z0-9_]')
522     for f in os.listdir(cdir):
523       if re.search(f): continue
524       subpath = pathname + '/' + f
525       try:
526         os.stat(subpath)
527       except FileNotFoundError:
528         log('entry skipped', subpath)
529         continue
530       cfg.read(subpath)
531       log('entry read', subpath)
532       
533   def oc_config(od,os, value, op):
534     nonlocal need_defcfg
535     need_defcfg = False
536     readconfig(value)
537
538   def dfs_less_detailed(dl):
539     return [df for df in DBG.iterconstants() if df <= dl]
540
541   def ds_default(od,os,dl,op):
542     global debug_set
543     debug_set = set(dfs_less_detailed(debug_def_detail))
544
545   def ds_select(od,os, spec, op):
546     for it in spec.split(','):
547
548       if it.startswith('-'):
549         mutator = debug_set.discard
550         it = it[1:]
551       else:
552         mutator = debug_set.add
553
554       if it == '+':
555         dfs = DBG.iterconstants()
556
557       else:
558         if it.endswith('+'):
559           mapper = dfs_less_detailed
560           it = it[0:len(it)-1]
561         else:
562           mapper = lambda x: [x]
563
564           try:
565             dfspec = DBG.lookupByName(it)
566           except ValueError:
567             optparser.error('unknown debug flag %s in --debug-select' % it)
568
569         dfs = mapper(dfspec)
570
571       for df in dfs:
572         mutator(df)
573
574   optparser.add_option('-D', '--debug',
575                        nargs=0,
576                        action='callback',
577                        help='enable default debug (to stdout)',
578                        callback= ds_default)
579
580   optparser.add_option('--debug-select',
581                        nargs=1,
582                        type='string',
583                        metavar='[-]DFLAG[+]|[-]+,...',
584                        help=
585 '''enable (`-': disable) each specified DFLAG;
586 `+': do same for all "more interesting" DFLAGSs;
587 just `+': all DFLAGs.
588   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
589                        action='callback',
590                        callback= ds_select)
591
592   optparser.add_option('-c', '--config',
593                        nargs=1,
594                        type='string',
595                        metavar='CONFIGFILE',
596                        dest='configfile',
597                        action='callback',
598                        callback= oc_config)
599
600   (opts, args) = optparser.parse_args()
601   if len(args): optparser.error('no non-option arguments please')
602
603   if need_defcfg:
604     readconfig('/etc/hippotat/config',   False)
605     readconfig('/etc/hippotat/config.d', False)
606
607   try:
608     (pss, pcs) = _cfg_process_putatives()
609     process_cfg(pss, pcs)
610   except (configparser.Error, ValueError):
611     traceback.print_exc(file=sys.stderr)
612     print('\nInvalid configuration, giving up.', file=sys.stderr)
613     sys.exit(12)
614
615   #print(repr(debug_set), file=sys.stderr)
616
617   log_formatter = twisted.logger.formatEventAsClassicLogText
618   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
619   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
620   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
621   stdsomething_obs = twisted.logger.FilteringLogObserver(
622     stderr_obs, [pred], stdout_obs
623   )
624   log_observer = twisted.logger.FilteringLogObserver(
625     stdsomething_obs, [LogNotBoringTwisted()]
626   )
627   #log_observer = stdsomething_obs
628   twisted.logger.globalLogBeginner.beginLoggingTo(
629     [ log_observer, crash_on_critical ]
630     )
631
632 def common_run():
633   log_debug(DBG.INIT, 'entering reactor')
634   if not _crashing: reactor.run()
635   print('CRASHED (end)', file=sys.stderr)