chiark / gitweb /
397dfb2c23459cf4f9a7c17bd4064b63cec2e215
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2 #
3 # Hippotat - Asinine IP Over HTTP program
4 # hippotatlib/__init__.py - common library code
5 #
6 # Copyright 2017 Ian Jackson
7 #
8 # GPLv3+
9 #
10 #    This program is free software: you can redistribute it and/or modify
11 #    it under the terms of the GNU General Public License as published by
12 #    the Free Software Foundation, either version 3 of the License, or
13 #    (at your option) any later version.
14 #
15 #    This program is distributed in the hope that it will be useful,
16 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
17 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 #    GNU General Public License for more details.
19 #
20 #    You should have received a copy of the GNU General Public License
21 #    along with this program, in the file GPLv3.  If not,
22 #    see <http://www.gnu.org/licenses/>.
23
24
25 import signal
26 signal.signal(signal.SIGINT, signal.SIG_DFL)
27
28 import sys
29 import os
30
31 from zope.interface import implementer
32
33 import twisted
34 from twisted.internet import reactor
35 import twisted.internet.endpoints
36 import twisted.logger
37 from twisted.logger import LogLevel
38 import twisted.python.constants
39 from twisted.python.constants import NamedConstant
40
41 import ipaddress
42 from ipaddress import AddressValueError
43
44 from optparse import OptionParser
45 import configparser
46 from configparser import ConfigParser
47 from configparser import NoOptionError
48
49 from functools import partial
50
51 import collections
52 import time
53 import codecs
54 import traceback
55
56 import re as regexp
57
58 import hippotatlib.slip as slip
59
60 class DBG(twisted.python.constants.Names):
61   INIT = NamedConstant()
62   CONFIG = NamedConstant()
63   ROUTE = NamedConstant()
64   DROP = NamedConstant()
65   OWNSOURCE = NamedConstant()
66   FLOW = NamedConstant()
67   HTTP = NamedConstant()
68   TWISTED = NamedConstant()
69   QUEUE = NamedConstant()
70   HTTP_CTRL = NamedConstant()
71   QUEUE_CTRL = NamedConstant()
72   HTTP_FULL = NamedConstant()
73   CTRL_DUMP = NamedConstant()
74   SLIP_FULL = NamedConstant()
75   DATA_COMPLETE = NamedConstant()
76
77 _hex_codec = codecs.getencoder('hex_codec')
78
79 #---------- logging ----------
80
81 org_stderr = sys.stderr
82
83 log = twisted.logger.Logger()
84
85 debug_set = set()
86 debug_def_detail = DBG.HTTP
87
88 def log_debug(dflag, msg, idof=None, d=None):
89   if dflag not in debug_set: return
90   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
91   if idof is not None:
92     msg = '[%#x] %s' % (id(idof), msg)
93   if d is not None:
94     trunc = ''
95     if not DBG.DATA_COMPLETE in debug_set:
96       if len(d) > 64:
97         d = d[0:64]
98         trunc = '...'
99     d = _hex_codec(d)[0].decode('ascii')
100     msg += ' ' + d + trunc
101   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
102
103 def logevent_is_boringtwisted(event):
104   try:
105     if event.get('log_level') != LogLevel.info:
106       return False
107     dflag = event.get('dflag')
108     if dflag is False                            : return False
109     if dflag                         in debug_set: return False
110     if dflag is None and DBG.TWISTED in debug_set: return False
111     return True
112   except Exception:
113     print(traceback.format_exc(), file=org_stderr)
114     return False
115
116 @implementer(twisted.logger.ILogFilterPredicate)
117 class LogNotBoringTwisted:
118   def __call__(self, event):
119     return (
120       twisted.logger.PredicateResult.no
121       if logevent_is_boringtwisted(event) else
122       twisted.logger.PredicateResult.yes
123     )
124
125 #---------- default config ----------
126
127 defcfg = '''
128 [DEFAULT]
129 max_batch_down = 65536
130 max_queue_time = 10
131 target_requests_outstanding = 3
132 http_timeout = 30
133 http_timeout_grace = 5
134 max_requests_outstanding = 6
135 max_batch_up = 4000
136 http_retry = 5
137 port = 80
138 vroutes = ''
139
140 #[server] or [<client>] overrides
141 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
142
143 # relating to virtual network
144 mtu = 1500
145
146 [SERVER]
147 server = SERVER
148 # addrs = 127.0.0.1 ::1
149 # url
150
151 # relating to virtual network
152 vvnetwork = 172.24.230.192
153 # vnetwork = <prefix>/<len>
154 # vadd  r  = <ipaddr>
155 # vrelay   = <ipaddr>
156
157
158 # [<client-ip4-or-ipv6-address>]
159 # password = <password>    # used by both, must match
160
161 [LIMIT]
162 max_batch_down = 262144
163 max_queue_time = 121
164 http_timeout = 121
165 target_requests_outstanding = 10
166 '''
167
168 # these need to be defined here so that they can be imported by import *
169 cfg = ConfigParser(strict=False)
170 optparser = OptionParser()
171
172 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
173 def mime_translate(s):
174   # SLIP-encoded packets cannot contain ESC ESC.
175   # Swap `-' and ESC.  The result cannot contain `--'
176   return s.translate(_mimetrans)
177
178 class ConfigResults:
179   def __init__(self):
180     pass
181   def __repr__(self):
182     return 'ConfigResults('+repr(self.__dict__)+')'
183
184 def log_discard(packet, iface, saddr, daddr, why):
185   log_debug(DBG.DROP,
186             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
187             d=packet)
188
189 #---------- packet parsing ----------
190
191 def packet_addrs(packet):
192   version = packet[0] >> 4
193   if version == 4:
194     addrlen = 4
195     saddroff = 3*4
196     factory = ipaddress.IPv4Address
197   elif version == 6:
198     addrlen = 16
199     saddroff = 2*4
200     factory = ipaddress.IPv6Address
201   else:
202     raise ValueError('unsupported IP version %d' % version)
203   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
204   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
205   return (saddr, daddr)
206
207 #---------- address handling ----------
208
209 def ipaddr(input):
210   try:
211     r = ipaddress.IPv4Address(input)
212   except AddressValueError:
213     r = ipaddress.IPv6Address(input)
214   return r
215
216 def ipnetwork(input):
217   try:
218     r = ipaddress.IPv4Network(input)
219   except NetworkValueError:
220     r = ipaddress.IPv6Network(input)
221   return r
222
223 #---------- ipif (SLIP) subprocess ----------
224
225 class SlipStreamDecoder():
226   def __init__(self, desc, on_packet):
227     self._buffer = b''
228     self._on_packet = on_packet
229     self._desc = desc
230     self._log('__init__')
231
232   def _log(self, msg, **kwargs):
233     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
234
235   def inputdata(self, data):
236     self._log('inputdata', d=data)
237     data = self._buffer + data
238     self._buffer = b''
239     packets = slip.decode(data, True)
240     self._buffer = packets.pop()
241     for packet in packets:
242       self._maybe_packet(packet)
243     self._log('bufremain', d=self._buffer)
244
245   def _maybe_packet(self, packet):
246     self._log('maybepacket', d=packet)
247     if len(packet):
248       self._on_packet(packet)
249
250   def flush(self):
251     self._log('flush')
252     data = self._buffer
253     self._buffer = b''
254     packets = slip.decode(data)
255     assert(len(packets) == 1)
256     self._maybe_packet(packets[0])
257
258 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
259   def __init__(self, router):
260     self._router = router
261     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
262   def connectionMade(self): pass
263   def outReceived(self, data):
264     self._decoder.inputdata(data)
265   def slip_on_packet(self, packet):
266     (saddr, daddr) = packet_addrs(packet)
267     if saddr.is_link_local or daddr.is_link_local:
268       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
269       return
270     self._router(packet, saddr, daddr)
271   def processEnded(self, status):
272     status.raiseException()
273
274 def start_ipif(command, router):
275   ipif = _IpifProcessProtocol(router)
276   reactor.spawnProcess(ipif,
277                        '/bin/sh',['sh','-xc', command],
278                        childFDs={0:'w', 1:'r', 2:2},
279                        env=None)
280   return ipif
281
282 def queue_inbound(ipif, packet):
283   log_debug(DBG.FLOW, "queue_inbound", d=packet)
284   ipif.transport.write(slip.delimiter)
285   ipif.transport.write(slip.encode(packet))
286   ipif.transport.write(slip.delimiter)
287
288 #---------- packet queue ----------
289
290 class PacketQueue():
291   def __init__(self, desc, max_queue_time):
292     self._desc = desc
293     assert(desc + '')
294     self._max_queue_time = max_queue_time
295     self._pq = collections.deque() # packets
296
297   def _log(self, dflag, msg, **kwargs):
298     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
299
300   def append(self, packet):
301     self._log(DBG.QUEUE, 'append', d=packet)
302     self._pq.append((time.monotonic(), packet))
303
304   def nonempty(self):
305     self._log(DBG.QUEUE, 'nonempty ?')
306     while True:
307       try: (queuetime, packet) = self._pq[0]
308       except IndexError:
309         self._log(DBG.QUEUE, 'nonempty ? empty.')
310         return False
311
312       age = time.monotonic() - queuetime
313       if age > self._max_queue_time:
314         # strip old packets off the front
315         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
316         self._pq.popleft()
317         continue
318
319       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
320       return True
321
322   def process(self, sizequery, moredata, max_batch):
323     # sizequery() should return size of batch so far
324     # moredata(s) should add s to batch
325     self._log(DBG.QUEUE, 'process...')
326     while True:
327       try: (dummy, packet) = self._pq[0]
328       except IndexError:
329         self._log(DBG.QUEUE, 'process... empty')
330         break
331
332       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
333
334       encoded = slip.encode(packet)
335       sofar = sizequery()  
336
337       self._log(DBG.QUEUE_CTRL,
338                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
339                 d=encoded)
340
341       if sofar > 0:
342         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
343           self._log(DBG.QUEUE_CTRL, 'process... overflow')
344           break
345         moredata(slip.delimiter)
346
347       moredata(encoded)
348       self._pq.popleft()
349
350 #---------- error handling ----------
351
352 _crashing = False
353
354 def crash(err):
355   global _crashing
356   _crashing = True
357   print('========== CRASH ==========', err,
358         '===========================', file=sys.stderr)
359   try: reactor.stop()
360   except twisted.internet.error.ReactorNotRunning: pass
361
362 def crash_on_defer(defer):
363   defer.addErrback(lambda err: crash(err))
364
365 def crash_on_critical(event):
366   if event.get('log_level') >= LogLevel.critical:
367     crash(twisted.logger.formatEvent(event))
368
369 #---------- config processing ----------
370
371 def _cfg_process_putatives():
372   servers = { }
373   clients = { }
374   # maps from abstract object to canonical name for cs's
375
376   def putative(cmap, abstract, canoncs):
377     try:
378       current_canoncs = cmap[abstract]
379     except KeyError:
380       pass
381     else:
382       assert(current_canoncs == canoncs)
383     cmap[abstract] = canoncs
384
385   server_pat = r'[-.0-9A-Za-z]+'
386   client_pat = r'[.:0-9a-f]+'
387   server_re = regexp.compile(server_pat)
388   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
389
390   for cs in cfg.sections():
391     if cs == 'LIMIT':
392       # plan A "[LIMIT]"
393       continue
394
395     try:
396       # plan B "[<client>]" part 1
397       ci = ipaddr(cs)
398     except AddressValueError:
399
400       if server_re.fullmatch(cs):
401         # plan C "[<servername>]"
402         putative(servers, cs, cs)
403         continue
404
405       if serverclient_re.fullmatch(cs):
406         # plan D "[<servername> <client>]" part 1
407         (pss,pcs) = cs.split(' ')
408
409         if pcs == 'LIMIT':
410           # plan E "[<servername> LIMIT]"
411           continue
412
413         try:
414           # plan D "[<servername> <client>]" part 2
415           ci = ipaddr(pc)
416         except AddressValueError:
417           # plan F "[<some thing we do not understand>]"
418           # well, we ignore this
419           print('warning: ignoring config section %s' % cs, file=sys.stderr)
420           continue
421
422         else: # no AddressValueError
423           # plan D "[<servername> <client]" part 3
424           putative(clients, ci, pcs)
425           putative(servers, pss, pss)
426           continue
427
428     else: # no AddressValueError
429       # plan B "[<client>" part 2
430       putative(clients, ci, cs)
431       continue
432
433   return (servers, clients)
434
435 def cfg_process_common(c, ss):
436   c.mtu = cfg.getint(ss, 'mtu')
437
438 def cfg_process_saddrs(c, ss):
439   class ServerAddr():
440     def __init__(self, port, addrspec):
441       self.port = port
442       # also self.addr
443       try:
444         self.addr = ipaddress.IPv4Address(addrspec)
445         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
446         self._inurl = b'%s'
447       except AddressValueError:
448         self.addr = ipaddress.IPv6Address(addrspec)
449         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
450         self._inurl = b'[%s]'
451     def make_endpoint(self):
452       return self._endpointfactory(reactor, self.port,
453                                    interface= '%s' % self.addr)
454     def url(self):
455       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
456       if self.port != 80: url += b':%d' % self.port
457       url += b'/'
458       return url
459     def __repr__(self):
460       return 'ServerAddr'+repr((self.port,self.addr))
461
462   c.port = cfg.getint(ss,'port')
463   c.saddrs = [ ]
464   for addrspec in cfg.get(ss, 'addrs').split():
465     sa = ServerAddr(c.port, addrspec)
466     c.saddrs.append(sa)
467
468 def cfg_process_vnetwork(c, ss):
469   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
470   if c.vnetwork.num_addresses < 3 + 2:
471     raise ValueError('vnetwork needs at least 2^3 addresses')
472
473 def cfg_process_vaddr(c, ss):
474   try:
475     c.vaddr = cfg.get(ss,'vaddr')
476   except NoOptionError:
477     cfg_process_vnetwork(c, ss)
478     c.vaddr = next(c.vnetwork.hosts())
479
480 def cfg_search_section(key,sections):
481   for section in sections:
482     if cfg.has_option(section, key):
483       return section
484   raise NoOptionError(key, repr(sections))
485
486 def cfg_search(getter,key,sections):
487   section = cfg_search_section(key,sections)
488   return getter(section, key)
489
490 def cfg_process_client_limited(cc,ss,sections,key):
491   val = cfg_search(cfg.getint, key, sections)
492   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
493   cc.__dict__[key] = min(val,lim)
494
495 def cfg_process_client_common(cc,ss,cs,ci):
496   # returns sections to search in, iff password is defined, otherwise None
497   cc.ci = ci
498
499   sections = ['%s %s' % (ss,cs),
500               cs,
501               ss,
502               'DEFAULT']
503
504   try: pwsection = cfg_search_section('password', sections)
505   except NoOptionError: return None
506     
507   pw = cfg.get(pwsection, 'password')
508   cc.password = pw.encode('utf-8')
509
510   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
511   cfg_process_client_limited(cc,ss,sections,'http_timeout')
512
513   return sections
514
515 def cfg_process_ipif(c, sections, varmap):
516   for d, s in varmap:
517     try: v = getattr(c, s)
518     except AttributeError: continue
519     setattr(c, d, v)
520
521   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
522
523   section = cfg_search_section('ipif', sections)
524   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
525
526 #---------- startup ----------
527
528 def log_debug_config(m):
529   if not DBG.CONFIG in debug_set: return
530   print('DBG.CONFIG:', m)
531
532 def common_startup(process_cfg):
533   # calls process_cfg(putative_clients, putative_servers)
534
535   # ConfigParser hates #-comments after values
536   trailingcomments_re = regexp.compile(r'#.*')
537   cfg.read_string(trailingcomments_re.sub('', defcfg))
538   need_defcfg = True
539
540   def readconfig(pathname, mandatory=True):
541     def log(m, p=pathname):
542       if not DBG.CONFIG in debug_set: return
543       log_debug_config('%s: %s' % (m, pathname))
544
545     try:
546       files = os.listdir(pathname)
547
548     except FileNotFoundError:
549       if mandatory: raise
550       log('skipped')
551       return
552
553     except NotADirectoryError:
554       cfg.read(pathname)
555       log('read file')
556       return
557
558     # is a directory
559     log('directory')
560     re = regexp.compile('[^-A-Za-z0-9_]')
561     for f in os.listdir(pathname):
562       if re.search(f): continue
563       subpath = pathname + '/' + f
564       try:
565         os.stat(subpath)
566       except FileNotFoundError:
567         log('entry skipped', subpath)
568         continue
569       cfg.read(subpath)
570       log('entry read', subpath)
571       
572   def oc_config(od,os, value, op):
573     nonlocal need_defcfg
574     need_defcfg = False
575     readconfig(value)
576
577   def oc_extra_config(od,os, value, op):
578     readconfig(value)
579
580   def read_defconfig():
581     readconfig('/etc/hippotat/config.d', False)
582     readconfig('/etc/hippotat/passwords.d', False)
583     readconfig('/etc/hippotat/master.cfg',   False)
584
585   def oc_defconfig(od,os, value, op):
586     nonlocal need_defcfg
587     need_defcfg = False
588     read_defconfig(value)
589
590   def dfs_less_detailed(dl):
591     return [df for df in DBG.iterconstants() if df <= dl]
592
593   def ds_default(od,os,dl,op):
594     global debug_set
595     debug_set.clear
596     debug_set |= set(dfs_less_detailed(debug_def_detail))
597
598   def ds_select(od,os, spec, op):
599     for it in spec.split(','):
600
601       if it.startswith('-'):
602         mutator = debug_set.discard
603         it = it[1:]
604       else:
605         mutator = debug_set.add
606
607       if it == '+':
608         dfs = DBG.iterconstants()
609
610       else:
611         if it.endswith('+'):
612           mapper = dfs_less_detailed
613           it = it[0:len(it)-1]
614         else:
615           mapper = lambda x: [x]
616
617           try:
618             dfspec = DBG.lookupByName(it)
619           except ValueError:
620             optparser.error('unknown debug flag %s in --debug-select' % it)
621
622         dfs = mapper(dfspec)
623
624       for df in dfs:
625         mutator(df)
626
627   optparser.add_option('-D', '--debug',
628                        nargs=0,
629                        action='callback',
630                        help='enable default debug (to stdout)',
631                        callback= ds_default)
632
633   optparser.add_option('--debug-select',
634                        nargs=1,
635                        type='string',
636                        metavar='[-]DFLAG[+]|[-]+,...',
637                        help=
638 '''enable (`-': disable) each specified DFLAG;
639 `+': do same for all "more interesting" DFLAGSs;
640 just `+': all DFLAGs.
641   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
642                        action='callback',
643                        callback= ds_select)
644
645   optparser.add_option('-c', '--config',
646                        nargs=1,
647                        type='string',
648                        metavar='CONFIGFILE',
649                        dest='configfile',
650                        action='callback',
651                        callback= oc_config)
652
653   optparser.add_option('--extra-config',
654                        nargs=1,
655                        type='string',
656                        metavar='CONFIGFILE',
657                        dest='configfile',
658                        action='callback',
659                        callback= oc_extra_config)
660
661   optparser.add_option('--default-config',
662                        action='callback',
663                        callback= oc_defconfig)
664
665   (opts, args) = optparser.parse_args()
666   if len(args): optparser.error('no non-option arguments please')
667
668   if need_defcfg:
669     read_defconfig()
670
671   try:
672     (pss, pcs) = _cfg_process_putatives()
673     process_cfg(opts, pss, pcs)
674   except (configparser.Error, ValueError):
675     traceback.print_exc(file=sys.stderr)
676     print('\nInvalid configuration, giving up.', file=sys.stderr)
677     sys.exit(12)
678
679
680   #print('X', debug_set, file=sys.stderr)
681
682   log_formatter = twisted.logger.formatEventAsClassicLogText
683   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
684   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
685   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
686   stdsomething_obs = twisted.logger.FilteringLogObserver(
687     stderr_obs, [pred], stdout_obs
688   )
689   global file_log_observer
690   file_log_observer = twisted.logger.FilteringLogObserver(
691     stdsomething_obs, [LogNotBoringTwisted()]
692   )
693   #log_observer = stdsomething_obs
694   twisted.logger.globalLogBeginner.beginLoggingTo(
695     [ file_log_observer, crash_on_critical ]
696     )
697
698 def common_run():
699   log_debug(DBG.INIT, 'entering reactor')
700   if not _crashing: reactor.run()
701   print('ENDED', file=sys.stderr)
702   sys.exit(16)