chiark / gitweb /
wip fixes
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 from zope.interface import implementer
10
11 import twisted
12 from twisted.internet import reactor
13 import twisted.internet.endpoints
14 import twisted.logger
15 from twisted.logger import LogLevel
16 import twisted.python.constants
17 from twisted.python.constants import NamedConstant
18
19 import ipaddress
20 from ipaddress import AddressValueError
21
22 from optparse import OptionParser
23 import configparser
24 from configparser import ConfigParser
25 from configparser import NoOptionError
26
27 from functools import partial
28
29 import collections
30 import time
31 import codecs
32 import traceback
33
34 import re as regexp
35
36 import hippotat.slip as slip
37
38 class DBG(twisted.python.constants.Names):
39   INIT = NamedConstant()
40   CONFIG = NamedConstant()
41   ROUTE = NamedConstant()
42   DROP = NamedConstant()
43   FLOW = NamedConstant()
44   HTTP = NamedConstant()
45   TWISTED = NamedConstant()
46   QUEUE = NamedConstant()
47   HTTP_CTRL = NamedConstant()
48   QUEUE_CTRL = NamedConstant()
49   HTTP_FULL = NamedConstant()
50   CTRL_DUMP = NamedConstant()
51   SLIP_FULL = NamedConstant()
52   DATA_COMPLETE = NamedConstant()
53
54 _hex_codec = codecs.getencoder('hex_codec')
55
56 #---------- logging ----------
57
58 org_stderr = sys.stderr
59
60 log = twisted.logger.Logger()
61
62 debug_set = set()
63 debug_def_detail = DBG.HTTP
64
65 def log_debug(dflag, msg, idof=None, d=None):
66   if dflag not in debug_set: return
67   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
68   if idof is not None:
69     msg = '[%#x] %s' % (id(idof), msg)
70   if d is not None:
71     trunc = ''
72     if not DBG.DATA_COMPLETE in debug_set:
73       if len(d) > 64:
74         d = d[0:64]
75         trunc = '...'
76     d = _hex_codec(d)[0].decode('ascii')
77     msg += ' ' + d + trunc
78   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
79
80 @implementer(twisted.logger.ILogFilterPredicate)
81 class LogNotBoringTwisted:
82   def __call__(self, event):
83     yes = twisted.logger.PredicateResult.yes
84     no  = twisted.logger.PredicateResult.no
85     try:
86       if event.get('log_level') != LogLevel.info:
87         return yes
88       dflag = event.get('dflag')
89       if dflag                         in debug_set: return yes
90       if dflag is None and DBG.TWISTED in debug_set: return yes
91       return no
92     except Exception:
93       print(traceback.format_exc(), file=org_stderr)
94       return yes
95
96 #---------- default config ----------
97
98 defcfg = '''
99 [DEFAULT]
100 max_batch_down = 65536
101 max_queue_time = 10
102 target_requests_outstanding = 3
103 http_timeout = 30
104 http_timeout_grace = 5
105 max_requests_outstanding = 6
106 max_batch_up = 4000
107 http_retry = 5
108 port = 80
109 vroutes = ''
110
111 #[server] or [<client>] overrides
112 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
113
114 # relating to virtual network
115 mtu = 1500
116
117 [SERVER]
118 server = SERVER
119 # addrs = 127.0.0.1 ::1
120 # url
121
122 # relating to virtual network
123 vvnetwork = 172.24.230.192
124 # vnetwork = <prefix>/<len>
125 # vadd  r  = <ipaddr>
126 # vrelay   = <ipaddr>
127
128
129 # [<client-ip4-or-ipv6-address>]
130 # password = <password>    # used by both, must match
131
132 [LIMIT]
133 max_batch_down = 262144
134 max_queue_time = 121
135 http_timeout = 121
136 target_requests_outstanding = 10
137 '''
138
139 # these need to be defined here so that they can be imported by import *
140 cfg = ConfigParser(strict=False)
141 optparser = OptionParser()
142
143 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
144 def mime_translate(s):
145   # SLIP-encoded packets cannot contain ESC ESC.
146   # Swap `-' and ESC.  The result cannot contain `--'
147   return s.translate(_mimetrans)
148
149 class ConfigResults:
150   def __init__(self):
151     pass
152   def __repr__(self):
153     return 'ConfigResults('+repr(self.__dict__)+')'
154
155 def log_discard(packet, iface, saddr, daddr, why):
156   log_debug(DBG.DROP,
157             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
158             d=packet)
159
160 #---------- packet parsing ----------
161
162 def packet_addrs(packet):
163   version = packet[0] >> 4
164   if version == 4:
165     addrlen = 4
166     saddroff = 3*4
167     factory = ipaddress.IPv4Address
168   elif version == 6:
169     addrlen = 16
170     saddroff = 2*4
171     factory = ipaddress.IPv6Address
172   else:
173     raise ValueError('unsupported IP version %d' % version)
174   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
175   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
176   return (saddr, daddr)
177
178 #---------- address handling ----------
179
180 def ipaddr(input):
181   try:
182     r = ipaddress.IPv4Address(input)
183   except AddressValueError:
184     r = ipaddress.IPv6Address(input)
185   return r
186
187 def ipnetwork(input):
188   try:
189     r = ipaddress.IPv4Network(input)
190   except NetworkValueError:
191     r = ipaddress.IPv6Network(input)
192   return r
193
194 #---------- ipif (SLIP) subprocess ----------
195
196 class SlipStreamDecoder():
197   def __init__(self, desc, on_packet):
198     self._buffer = b''
199     self._on_packet = on_packet
200     self._desc = desc
201     self._log('__init__')
202
203   def _log(self, msg, **kwargs):
204     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
205
206   def inputdata(self, data):
207     self._log('inputdata', d=data)
208     packets = slip.decode(data)
209     packets[0] = self._buffer + packets[0]
210     self._buffer = packets.pop()
211     for packet in packets:
212       self._maybe_packet(packet)
213     self._log('bufremain', d=self._buffer)
214
215   def _maybe_packet(self, packet):
216     self._log('maybepacket', d=packet)
217     if len(packet):
218       self._on_packet(packet)
219
220   def flush(self):
221     self._log('flush')
222     self._maybe_packet(self._buffer)
223     self._buffer = b''
224
225 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
226   def __init__(self, router):
227     self._router = router
228     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
229   def connectionMade(self): pass
230   def outReceived(self, data):
231     self._decoder.inputdata(data)
232   def slip_on_packet(self, packet):
233     (saddr, daddr) = packet_addrs(packet)
234     if saddr.is_link_local or daddr.is_link_local:
235       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
236       return
237     self._router(packet, saddr, daddr)
238   def processEnded(self, status):
239     status.raiseException()
240
241 def start_ipif(command, router):
242   global ipif
243   ipif = _IpifProcessProtocol(router)
244   reactor.spawnProcess(ipif,
245                        '/bin/sh',['sh','-xc', command],
246                        childFDs={0:'w', 1:'r', 2:2},
247                        env=None)
248
249 def queue_inbound(packet):
250   log_debug(DBG.FLOW, "queue_inbound", d=packet)
251   ipif.transport.write(slip.delimiter)
252   ipif.transport.write(slip.encode(packet))
253   ipif.transport.write(slip.delimiter)
254
255 #---------- packet queue ----------
256
257 class PacketQueue():
258   def __init__(self, desc, max_queue_time):
259     self._desc = desc
260     assert(desc + '')
261     self._max_queue_time = max_queue_time
262     self._pq = collections.deque() # packets
263
264   def _log(self, dflag, msg, **kwargs):
265     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
266
267   def append(self, packet):
268     self._log(DBG.QUEUE, 'append', d=packet)
269     self._pq.append((time.monotonic(), packet))
270
271   def nonempty(self):
272     self._log(DBG.QUEUE, 'nonempty ?')
273     while True:
274       try: (queuetime, packet) = self._pq[0]
275       except IndexError:
276         self._log(DBG.QUEUE, 'nonempty ? empty.')
277         return False
278
279       age = time.monotonic() - queuetime
280       if age > self._max_queue_time:
281         # strip old packets off the front
282         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
283         self._pq.popleft()
284         continue
285
286       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
287       return True
288
289   def process(self, sizequery, moredata, max_batch):
290     # sizequery() should return size of batch so far
291     # moredata(s) should add s to batch
292     self._log(DBG.QUEUE, 'process...')
293     while True:
294       try: (dummy, packet) = self._pq[0]
295       except IndexError:
296         self._log(DBG.QUEUE, 'process... empty')
297         break
298
299       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
300
301       encoded = slip.encode(packet)
302       sofar = sizequery()  
303
304       self._log(DBG.QUEUE_CTRL,
305                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
306                 d=encoded)
307
308       if sofar > 0:
309         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
310           self._log(DBG.QUEUE_CTRL, 'process... overflow')
311           break
312         moredata(slip.delimiter)
313
314       moredata(encoded)
315       self._pq.popleft()
316
317 #---------- error handling ----------
318
319 _crashing = False
320
321 def crash(err):
322   global _crashing
323   _crashing = True
324   print('========== CRASH ==========', err,
325         '===========================', file=sys.stderr)
326   try: reactor.stop()
327   except twisted.internet.error.ReactorNotRunning: pass
328
329 def crash_on_defer(defer):
330   defer.addErrback(lambda err: crash(err))
331
332 def crash_on_critical(event):
333   if event.get('log_level') >= LogLevel.critical:
334     crash(twisted.logger.formatEvent(event))
335
336 #---------- config processing ----------
337
338 def _cfg_process_putatives():
339   servers = { }
340   clients = { }
341   # maps from abstract object to canonical name for cs's
342
343   def putative(cmap, abstract, canoncs):
344     try:
345       current_canoncs = cmap[abstract]
346     except KeyError:
347       pass
348     else:
349       assert(current_canoncs == canoncs)
350     cmap[abstract] = canoncs
351
352   server_pat = r'[-.0-9A-Za-z]+'
353   client_pat = r'[.:0-9a-f]+'
354   server_re = regexp.compile(server_pat)
355   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
356
357   for cs in cfg.sections():
358     if cs == 'LIMIT':
359       # plan A "[LIMIT]"
360       continue
361
362     try:
363       # plan B "[<client>]" part 1
364       ci = ipaddr(cs)
365     except AddressValueError:
366
367       if server_re.fullmatch(cs):
368         # plan C "[<servername>]"
369         putative(servers, cs, cs)
370         continue
371
372       if serverclient_re.fullmatch(cs):
373         # plan D "[<servername> <client>]" part 1
374         (pss,pcs) = cs.split(' ')
375
376         if pcs == 'LIMIT':
377           # plan E "[<servername> LIMIT]"
378           continue
379
380         try:
381           # plan D "[<servername> <client>]" part 2
382           ci = ipaddr(pc)
383         except AddressValueError:
384           # plan F "[<some thing we do not understand>]"
385           # well, we ignore this
386           print('warning: ignoring config section %s' % cs, file=sys.stderr)
387           continue
388
389         else: # no AddressValueError
390           # plan D "[<servername> <client]" part 3
391           putative(clients, ci, pcs)
392           putative(servers, pss, pss)
393           continue
394
395     else: # no AddressValueError
396       # plan B "[<client>" part 2
397       putative(clients, ci, cs)
398       continue
399
400   return (servers, clients)
401
402 def cfg_process_common(ss):
403   c.mtu = cfg.getint(ss, 'mtu')
404
405 def cfg_process_saddrs(c, ss):
406   class ServerAddr():
407     def __init__(self, port, addrspec):
408       self.port = port
409       # also self.addr
410       try:
411         self.addr = ipaddress.IPv4Address(addrspec)
412         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
413         self._inurl = b'%s'
414       except AddressValueError:
415         self.addr = ipaddress.IPv6Address(addrspec)
416         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
417         self._inurl = b'[%s]'
418     def make_endpoint(self):
419       return self._endpointfactory(reactor, self.port, self.addr)
420     def url(self):
421       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
422       if self.port != 80: url += b':%d' % self.port
423       url += b'/'
424       return url
425
426   c.port = cfg.getint(ss,'port')
427   c.saddrs = [ ]
428   for addrspec in cfg.get(ss, 'addrs').split():
429     sa = ServerAddr(c.port, addrspec)
430     c.saddrs.append(sa)
431
432 def cfg_process_vnetwork(c, ss):
433   c.network = ipnetwork(cfg.get(ss,'network'))
434   if c.network.num_addresses < 3 + 2:
435     raise ValueError('network needs at least 2^3 addresses')
436
437 def cfg_process_vaddr(c, ss):
438   try:
439     c.vaddr = cfg.get(ss,'server')
440   except NoOptionError:
441     cfg_process_vnetwork(c, ss)
442     c.vaddr = next(c.network.hosts())
443
444 def cfg_search_section(key,sections):
445   for section in sections:
446     if cfg.has_option(section, key):
447       return section
448   raise NoOptionError(key, repr(sections))
449
450 def cfg_search(getter,key,sections):
451   section = cfg_search_section(key,sections)
452   return getter(section, key)
453
454 def cfg_process_client_limited(cc,ss,sections,key):
455   val = cfg_search(cfg.getint, key, sections)
456   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
457   cc.__dict__[key] = min(val,lim)
458
459 def cfg_process_client_common(cc,ss,cs,ci):
460   # returns sections to search in, iff password is defined, otherwise None
461   cc.ci = ci
462
463   sections = ['%s %s' % (ss,cs),
464               cs,
465               ss,
466               'DEFAULT']
467
468   try: pwsection = cfg_search_section('password', sections)
469   except NoOptionError: return None
470     
471   pw = cfg.get(pwsection, 'password')
472   pw = pw.encode('utf-8')
473
474   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
475   cfg_process_client_limited(cc,ss,sections,'http_timeout')
476
477   return sections
478
479 def cfg_process_ipif(c, sections, varmap):
480   for d, s in varmap:
481     try: v = getattr(c, s)
482     except AttributeError: continue
483     setattr(c, d, v)
484
485   print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
486
487   section = cfg_search_section('ipif', sections)
488   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
489
490 #---------- startup ----------
491
492 def common_startup(process_cfg):
493   # calls process_cfg(putative_clients, putative_servers)
494
495   # ConfigParser hates #-comments after values
496   trailingcomments_re = regexp.compile(r'#.*')
497   cfg.read_string(trailingcomments_re.sub('', defcfg))
498   need_defcfg = True
499
500   def readconfig(pathname, mandatory=True):
501     def log(m, p=pathname):
502       if not DBG.CONFIG in debug_set: return
503       print('DBG.CONFIG: %s: %s' % (m, pathname))
504
505     try:
506       files = os.listdir(pathname)
507
508     except FileNotFoundError:
509       if mandatory: raise
510       log('skipped')
511       return
512
513     except NotADirectoryError:
514       cfg.read(pathname)
515       log('read file')
516       return
517
518     # is a directory
519     log('directory')
520     re = regexp.compile('[^-A-Za-z0-9_]')
521     for f in os.listdir(cdir):
522       if re.search(f): continue
523       subpath = pathname + '/' + f
524       try:
525         os.stat(subpath)
526       except FileNotFoundError:
527         log('entry skipped', subpath)
528         continue
529       cfg.read(subpath)
530       log('entry read', subpath)
531       
532   def oc_config(od,os, value, op):
533     nonlocal need_defcfg
534     need_defcfg = False
535     readconfig(value)
536
537   def dfs_less_detailed(dl):
538     return [df for df in DBG.iterconstants() if df <= dl]
539
540   def ds_default(od,os,dl,op):
541     global debug_set
542     debug_set = set(dfs_less_detailed(debug_def_detail))
543
544   def ds_select(od,os, spec, op):
545     for it in spec.split(','):
546
547       if it.startswith('-'):
548         mutator = debug_set.discard
549         it = it[1:]
550       else:
551         mutator = debug_set.add
552
553       if it == '+':
554         dfs = DBG.iterconstants()
555
556       else:
557         if it.endswith('+'):
558           mapper = dfs_less_detailed
559           it = it[0:len(it)-1]
560         else:
561           mapper = lambda x: [x]
562
563           try:
564             dfspec = DBG.lookupByName(it)
565           except ValueError:
566             optparser.error('unknown debug flag %s in --debug-select' % it)
567
568         dfs = mapper(dfspec)
569
570       for df in dfs:
571         mutator(df)
572
573   optparser.add_option('-D', '--debug',
574                        nargs=0,
575                        action='callback',
576                        help='enable default debug (to stdout)',
577                        callback= ds_default)
578
579   optparser.add_option('--debug-select',
580                        nargs=1,
581                        type='string',
582                        metavar='[-]DFLAG[+]|[-]+,...',
583                        help=
584 '''enable (`-': disable) each specified DFLAG;
585 `+': do same for all "more interesting" DFLAGSs;
586 just `+': all DFLAGs.
587   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
588                        action='callback',
589                        callback= ds_select)
590
591   optparser.add_option('-c', '--config',
592                        nargs=1,
593                        type='string',
594                        metavar='CONFIGFILE',
595                        dest='configfile',
596                        action='callback',
597                        callback= oc_config)
598
599   (opts, args) = optparser.parse_args()
600   if len(args): optparser.error('no non-option arguments please')
601
602   if need_defcfg:
603     readconfig('/etc/hippotat/config',   False)
604     readconfig('/etc/hippotat/config.d', False)
605
606   try:
607     (pss, pcs) = _cfg_process_putatives()
608     process_cfg(pss, pcs)
609   except (configparser.Error, ValueError):
610     traceback.print_exc(file=sys.stderr)
611     print('\nInvalid configuration, giving up.', file=sys.stderr)
612     sys.exit(12)
613
614   #print(repr(debug_set), file=sys.stderr)
615
616   log_formatter = twisted.logger.formatEventAsClassicLogText
617   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
618   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
619   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
620   stdsomething_obs = twisted.logger.FilteringLogObserver(
621     stderr_obs, [pred], stdout_obs
622   )
623   log_observer = twisted.logger.FilteringLogObserver(
624     stdsomething_obs, [LogNotBoringTwisted()]
625   )
626   #log_observer = stdsomething_obs
627   twisted.logger.globalLogBeginner.beginLoggingTo(
628     [ log_observer, crash_on_critical ]
629     )
630
631 def common_run():
632   log_debug(DBG.INIT, 'entering reactor')
633   if not _crashing: reactor.run()
634   print('CRASHED (end)', file=sys.stderr)