chiark / gitweb /
reorg config - will break
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag                         in debug_set: return yes
90       if dflag is None and DBG.TWISTED in debug_set: return yes
91       return no
92     except Exception:
93       print(traceback.format_exc(), file=org_stderr)
94       return yes
95
96 #---------- default config ----------
97
98 defcfg = '''
99 [DEFAULT]
100 max_batch_down = 65536
101 max_queue_time = 10
102 target_requests_outstanding = 3
103 http_timeout = 30
104 http_timeout_grace = 5
105 max_requests_outstanding = 6
106 max_batch_up = 4000
107 http_retry = 5
108 port = 80
109
110 #[server] or [<client>] overrides
111 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
112
113 # relating to virtual network
114 mtu = 1500
115
116 [SERVER]
117 server = SERVER
118 # addrs = 127.0.0.1 ::1
119 # url
120
121 # relating to virtual network
122 routes = ''
123 vnetwork = 172.24.230.192
124 # network = <prefix>/<len>
125 # server  = <ipaddr>
126 # relay   = <ipaddr>
127
128
129 # [<client-ip4-or-ipv6-address>]
130 # password = <password>    # used by both, must match
131
132 [LIMIT]
133 max_batch_down = 262144
134 max_queue_time = 121
135 http_timeout = 121
136 target_requests_outstanding = 10
137 '''
138
139 # these need to be defined here so that they can be imported by import *
140 cfg = ConfigParser(strict=False)
141 optparser = OptionParser()
142
143 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
144 def mime_translate(s):
145   # SLIP-encoded packets cannot contain ESC ESC.
146   # Swap `-' and ESC.  The result cannot contain `--'
147   return s.translate(_mimetrans)
148
149 class ConfigResults:
150   def __init__(self):
151     pass
152   def __repr__(self):
153     return 'ConfigResults('+repr(self.__dict__)+')'
154
155 def log_discard(packet, iface, saddr, daddr, why):
156   log_debug(DBG.DROP,
157             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
158             d=packet)
159
160 #---------- packet parsing ----------
161
162 def packet_addrs(packet):
163   version = packet[0] >> 4
164   if version == 4:
165     addrlen = 4
166     saddroff = 3*4
167     factory = ipaddress.IPv4Address
168   elif version == 6:
169     addrlen = 16
170     saddroff = 2*4
171     factory = ipaddress.IPv6Address
172   else:
173     raise ValueError('unsupported IP version %d' % version)
174   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
175   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
176   return (saddr, daddr)
177
178 #---------- address handling ----------
179
180 def ipaddr(input):
181   try:
182     r = ipaddress.IPv4Address(input)
183   except AddressValueError:
184     r = ipaddress.IPv6Address(input)
185   return r
186
187 def ipnetwork(input):
188   try:
189     r = ipaddress.IPv4Network(input)
190   except NetworkValueError:
191     r = ipaddress.IPv6Network(input)
192   return r
193
194 #---------- ipif (SLIP) subprocess ----------
195
196 class SlipStreamDecoder():
197   def __init__(self, desc, on_packet):
198     self._buffer = b''
199     self._on_packet = on_packet
200     self._desc = desc
201     self._log('__init__')
202
203   def _log(self, msg, **kwargs):
204     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
205
206   def inputdata(self, data):
207     self._log('inputdata', d=data)
208     packets = slip.decode(data)
209     packets[0] = self._buffer + packets[0]
210     self._buffer = packets.pop()
211     for packet in packets:
212       self._maybe_packet(packet)
213     self._log('bufremain', d=self._buffer)
214
215   def _maybe_packet(self, packet):
216     self._log('maybepacket', d=packet)
217     if len(packet):
218       self._on_packet(packet)
219
220   def flush(self):
221     self._log('flush')
222     self._maybe_packet(self._buffer)
223     self._buffer = b''
224
225 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
226   def __init__(self, router):
227     self._router = router
228     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
229   def connectionMade(self): pass
230   def outReceived(self, data):
231     self._decoder.inputdata(data)
232   def slip_on_packet(self, packet):
233     (saddr, daddr) = packet_addrs(packet)
234     if saddr.is_link_local or daddr.is_link_local:
235       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
236       return
237     self._router(packet, saddr, daddr)
238   def processEnded(self, status):
239     status.raiseException()
240
241 def start_ipif(command, router):
242   global ipif
243   ipif = _IpifProcessProtocol(router)
244   reactor.spawnProcess(ipif,
245                        '/bin/sh',['sh','-xc', command],
246                        childFDs={0:'w', 1:'r', 2:2},
247                        env=None)
248
249 def queue_inbound(packet):
250   log_debug(DBG.FLOW, "queue_inbound", d=packet)
251   ipif.transport.write(slip.delimiter)
252   ipif.transport.write(slip.encode(packet))
253   ipif.transport.write(slip.delimiter)
254
255 #---------- packet queue ----------
256
257 class PacketQueue():
258   def __init__(self, desc, max_queue_time):
259     self._desc = desc
260     assert(desc + '')
261     self._max_queue_time = max_queue_time
262     self._pq = collections.deque() # packets
263
264   def _log(self, dflag, msg, **kwargs):
265     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
266
267   def append(self, packet):
268     self._log(DBG.QUEUE, 'append', d=packet)
269     self._pq.append((time.monotonic(), packet))
270
271   def nonempty(self):
272     self._log(DBG.QUEUE, 'nonempty ?')
273     while True:
274       try: (queuetime, packet) = self._pq[0]
275       except IndexError:
276         self._log(DBG.QUEUE, 'nonempty ? empty.')
277         return False
278
279       age = time.monotonic() - queuetime
280       if age > self._max_queue_time:
281         # strip old packets off the front
282         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
283         self._pq.popleft()
284         continue
285
286       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
287       return True
288
289   def process(self, sizequery, moredata, max_batch):
290     # sizequery() should return size of batch so far
291     # moredata(s) should add s to batch
292     self._log(DBG.QUEUE, 'process...')
293     while True:
294       try: (dummy, packet) = self._pq[0]
295       except IndexError:
296         self._log(DBG.QUEUE, 'process... empty')
297         break
298
299       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
300
301       encoded = slip.encode(packet)
302       sofar = sizequery()  
303
304       self._log(DBG.QUEUE_CTRL,
305                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
306                 d=encoded)
307
308       if sofar > 0:
309         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
310           self._log(DBG.QUEUE_CTRL, 'process... overflow')
311           break
312         moredata(slip.delimiter)
313
314       moredata(encoded)
315       self._pq.popleft()
316
317 #---------- error handling ----------
318
319 _crashing = False
320
321 def crash(err):
322   global _crashing
323   _crashing = True
324   print('========== CRASH ==========', err,
325         '===========================', file=sys.stderr)
326   try: reactor.stop()
327   except twisted.internet.error.ReactorNotRunning: pass
328
329 def crash_on_defer(defer):
330   defer.addErrback(lambda err: crash(err))
331
332 def crash_on_critical(event):
333   if event.get('log_level') >= LogLevel.critical:
334     crash(twisted.logger.formatEvent(event))
335
336 #---------- config processing ----------
337
338 def _cfg_process_putatives():
339   servers = { }
340   clients = { }
341   # maps from abstract object to canonical name for cs's
342
343   def putative(cmap, abstract, canoncs):
344     try:
345       current_canoncs = cmap[abstract]
346     except KeyError:
347       pass
348     else:
349       assert(current_canoncs == canoncs)
350     cmap[abstract] = canoncs
351
352   server_pat = r'[-.0-9A-Za-z]+'
353   client_pat = r'[.:0-9a-f]+'
354   server_re = regexp.compile(server_pat)
355   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
356
357   for cs in cfg.sections():
358     if cs = 'LIMIT':
359       # plan A "[LIMIT]"
360       continue
361
362     try:
363       # plan B "[<client>]" part 1
364       ci = ipaddr(cs)
365     except AddressValueError:
366
367       if server_re.fullmatch(cs):
368         # plan C "[<servername>]"
369         putative(servers, cs, cs)
370         continue
371
372       if serverclient_re.fullmatch(cs):
373         # plan D "[<servername> <client>]" part 1
374         (pss,pcs) = cs.split(' ')
375
376         if pcs = 'LIMIT':
377           # plan E "[<servername> LIMIT]"
378           continue
379
380         try:
381           # plan D "[<servername> <client>]" part 2
382           ci = ipaddr(pc)
383         except AddressValueError:
384           # plan F "[<some thing we do not understand>]"
385           # well, we ignore this
386           print('warning: ignoring config section %s' % cs, file=sys.stderr)
387           continue
388
389         else: # no AddressValueError
390           # plan D "[<servername> <client]" part 3
391           putative(clients, ci, pcs)
392           putative(servers, pss, pss)
393           continue
394
395     else: # no AddressValueError
396       # plan B "[<client>" part 2
397       putative(clients, ci, cs)
398       continue
399
400   return (servers, clients)
401
402 def cfg_process_common(ss):
403   c.mtu = cfg.getint(ss, 'mtu')
404
405 def cfg_process_saddrs(c, ss):
406   class ServerAddr():
407     def __init__(self, port, addrspec):
408       self.port = port
409       # also self.addr
410       try:
411         self.addr = ipaddress.IPv4Address(addrspec)
412         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
413         self._inurl = b'%s'
414       except AddressValueError:
415         self.addr = ipaddress.IPv6Address(addrspec)
416         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
417         self._inurl = b'[%s]'
418     def make_endpoint(self):
419       return self._endpointfactory(reactor, self.port, self.addr)
420     def url(self):
421       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
422       if self.port != 80: url += b':%d' % self.port
423       url += b'/'
424       return url
425
426   c.port = cfg.getint(ss,'port')
427   c.saddrs = [ ]
428   for addrspec in cfg.get(ss, 'addrs').split():
429     sa = ServerAddr(c.port, addrspec)
430     c.saddrs.append(sa)
431
432 def cfg_process_vnetwork(c, ss):
433   c.network = ipnetwork(cfg.get(ss,'network'))
434   if c.network.num_addresses < 3 + 2:
435     raise ValueError('network needs at least 2^3 addresses')
436
437 def cfg_process_vaddr():
438   try:
439     c.server = cfg.get('virtual','server')
440   except NoOptionError:
441     process_cfg_network()
442     c.server = next(c.network.hosts())
443
444 def cfg_search_section(key,sections):
445   for section in sections:
446     if cfg.has_option(section, key):
447       return section
448   raise NoOptionError('missing %s %s' % (key, repr(sections)))
449
450 def cfg_search(getter,key,sections):
451   section = cfg_search_section(key,sections)
452   return getter(section, key)
453
454 def cfg_process_client_limited(cc,ss,sections,key):
455   val = cfg_search(cfg.getint, key, sections)
456   lim = cfg_search(cfg.getint, key, '%s LIMIT' % ss, 'LIMIT')
457   cc.__dict__[key] = min(val,lim)
458
459 def cfg_process_client_common(cc,ss,cs,ci):
460   # returns sections to search in, iff password is defined, otherwise None
461   cc.ci = ci
462
463   sections = ['%s %s' % section,
464               cs,
465               ss,
466               'DEFAULT']
467
468   try: pwsection = cfg_search_section('password', sections)
469   except NoOptionError: return None
470     
471   pw = cfg.get(pwsection, 'password')
472   pw = pw.encode('utf-8')
473
474   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
475   cfg_process_client_limited(cc,ss,sections,'http_timeout')
476
477   return sections
478
479 def process_cfg_ipif(c, sections, varmap):
480   for d, s in varmap:
481     try: v = getattr(c, s)
482     except AttributeError: continue
483     setattr(c, d, v)
484
485   section = cfg_search_section('ipif', sections)
486   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
487
488 #---------- startup ----------
489
490 def common_startup(process_cfg):
491   # calls process_cfg(putative_clients, putative_servers)
492
493   # ConfigParser hates #-comments after values
494   trailingcomments_re = regexp.compile(r'#.*')
495   cfg.read_string(trailingcomments_re.sub('', defcfg))
496   need_defcfg = True
497
498   def readconfig(pathname, mandatory=True):
499     def log(m, p=pathname):
500       if not DBG.CONFIG in debug_set: return
501       print('DBG.CONFIG: %s: %s' % (m, pathname))
502
503     try:
504       files = os.listdir(pathname)
505
506     except FileNotFoundError:
507       if mandatory: raise
508       log('skipped')
509       return
510
511     except NotADirectoryError:
512       cfg.read(pathname)
513       log('read file')
514       return
515
516     # is a directory
517     log('directory')
518     re = regexp.compile('[^-A-Za-z0-9_]')
519     for f in os.listdir(cdir):
520       if re.search(f): continue
521       subpath = pathname + '/' + f
522       try:
523         os.stat(subpath)
524       except FileNotFoundError:
525         log('entry skipped', subpath)
526         continue
527       cfg.read(subpath)
528       log('entry read', subpath)
529       
530   def oc_config(od,os, value, op):
531     nonlocal need_defcfg
532     need_defcfg = False
533     readconfig(value)
534
535   def dfs_less_detailed(dl):
536     return [df for df in DBG.iterconstants() if df <= dl]
537
538   def ds_default(od,os,dl,op):
539     global debug_set
540     debug_set = set(dfs_less_detailed(debug_def_detail))
541
542   def ds_select(od,os, spec, op):
543     for it in spec.split(','):
544
545       if it.startswith('-'):
546         mutator = debug_set.discard
547         it = it[1:]
548       else:
549         mutator = debug_set.add
550
551       if it == '+':
552         dfs = DBG.iterconstants()
553
554       else:
555         if it.endswith('+'):
556           mapper = dfs_less_detailed
557           it = it[0:len(it)-1]
558         else:
559           mapper = lambda x: [x]
560
561           try:
562             dfspec = DBG.lookupByName(it)
563           except ValueError:
564             optparser.error('unknown debug flag %s in --debug-select' % it)
565
566         dfs = mapper(dfspec)
567
568       for df in dfs:
569         mutator(df)
570
571   optparser.add_option('-D', '--debug',
572                        nargs=0,
573                        action='callback',
574                        help='enable default debug (to stdout)',
575                        callback= ds_default)
576
577   optparser.add_option('--debug-select',
578                        nargs=1,
579                        type='string',
580                        metavar='[-]DFLAG[+]|[-]+,...',
581                        help=
582 '''enable (`-': disable) each specified DFLAG;
583 `+': do same for all "more interesting" DFLAGSs;
584 just `+': all DFLAGs.
585   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
586                        action='callback',
587                        callback= ds_select)
588
589   optparser.add_option('-c', '--config',
590                        nargs=1,
591                        type='string',
592                        metavar='CONFIGFILE',
593                        dest='configfile',
594                        action='callback',
595                        callback= oc_config)
596
597   (opts, args) = optparser.parse_args()
598   if len(args): optparser.error('no non-option arguments please')
599
600   if need_defcfg:
601     readconfig('/etc/hippotat/config',   False)
602     readconfig('/etc/hippotat/config.d', False)
603
604   try:
605     (pss, pcs) = process_cfg_putatives()
606     process_cfg(pss, pcs)
607   except (configparser.Error, ValueError):
608     traceback.print_exc(file=sys.stderr)
609     print('\nInvalid configuration, giving up.', file=sys.stderr)
610     sys.exit(12)
611
612   #print(repr(debug_set), file=sys.stderr)
613
614   log_formatter = twisted.logger.formatEventAsClassicLogText
615   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
616   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
617   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
618   stdsomething_obs = twisted.logger.FilteringLogObserver(
619     stderr_obs, [pred], stdout_obs
620   )
621   log_observer = twisted.logger.FilteringLogObserver(
622     stdsomething_obs, [LogNotBoringTwisted()]
623   )
624   #log_observer = stdsomething_obs
625   twisted.logger.globalLogBeginner.beginLoggingTo(
626     [ log_observer, crash_on_critical ]
627     )
628
629 def common_run():
630   log_debug(DBG.INIT, 'entering reactor')
631   if not _crashing: reactor.run()
632   print('CRASHED (end)', file=sys.stderr)