chiark / gitweb /
ab404fab9ac0188e955528e2764bde428eccd1de
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2 #
3 # Hippotat - Asinine IP Over HTTP program
4 # hippotatlib/__init__.py - common library code
5 #
6 # Copyright 2017 Ian Jackson
7 #
8 # GPLv3+
9 #
10 #    This program is free software: you can redistribute it and/or modify
11 #    it under the terms of the GNU General Public License as published by
12 #    the Free Software Foundation, either version 3 of the License, or
13 #    (at your option) any later version.
14 #
15 #    This program is distributed in the hope that it will be useful,
16 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
17 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 #    GNU General Public License for more details.
19 #
20 #    You should have received a copy of the GNU General Public License
21 #    along with this program, in the file GPLv3.  If not,
22 #    see <http://www.gnu.org/licenses/>.
23
24
25 import signal
26 signal.signal(signal.SIGINT, signal.SIG_DFL)
27
28 import sys
29 import os
30
31 from zope.interface import implementer
32
33 import twisted
34 from twisted.internet import reactor
35 import twisted.internet.endpoints
36 import twisted.logger
37 from twisted.logger import LogLevel
38 import twisted.python.constants
39 from twisted.python.constants import NamedConstant
40
41 import ipaddress
42 from ipaddress import AddressValueError
43
44 from optparse import OptionParser
45 import configparser
46 from configparser import ConfigParser
47 from configparser import NoOptionError
48
49 from functools import partial
50
51 import collections
52 import time
53 import codecs
54 import traceback
55
56 import re as regexp
57
58 import hippotatlib.slip as slip
59
60 class DBG(twisted.python.constants.Names):
61   INIT = NamedConstant()
62   CONFIG = NamedConstant()
63   ROUTE = NamedConstant()
64   DROP = NamedConstant()
65   OWNSOURCE = NamedConstant()
66   FLOW = NamedConstant()
67   HTTP = NamedConstant()
68   TWISTED = NamedConstant()
69   QUEUE = NamedConstant()
70   HTTP_CTRL = NamedConstant()
71   QUEUE_CTRL = NamedConstant()
72   HTTP_FULL = NamedConstant()
73   CTRL_DUMP = NamedConstant()
74   SLIP_FULL = NamedConstant()
75   DATA_COMPLETE = NamedConstant()
76
77 _hex_codec = codecs.getencoder('hex_codec')
78
79 #---------- logging ----------
80
81 org_stderr = sys.stderr
82
83 log = twisted.logger.Logger()
84
85 debug_set = set()
86 debug_def_detail = DBG.HTTP
87
88 def log_debug(dflag, msg, idof=None, d=None):
89   if dflag not in debug_set: return
90   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
91   if idof is not None:
92     msg = '[%#x] %s' % (id(idof), msg)
93   if d is not None:
94     trunc = ''
95     if not DBG.DATA_COMPLETE in debug_set:
96       if len(d) > 64:
97         d = d[0:64]
98         trunc = '...'
99     d = _hex_codec(d)[0].decode('ascii')
100     msg += ' ' + d + trunc
101   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
102
103 def logevent_is_boringtwisted(event):
104   try:
105     if event.get('log_level') != LogLevel.info:
106       return False
107     dflag = event.get('dflag')
108     if dflag is False                            : return False
109     if dflag                         in debug_set: return False
110     if dflag is None and DBG.TWISTED in debug_set: return False
111     return True
112   except Exception:
113     print(traceback.format_exc(), file=org_stderr)
114     return False
115
116 @implementer(twisted.logger.ILogFilterPredicate)
117 class LogNotBoringTwisted:
118   def __call__(self, event):
119     return (
120       twisted.logger.PredicateResult.no
121       if logevent_is_boringtwisted(event) else
122       twisted.logger.PredicateResult.yes
123     )
124
125 #---------- default config ----------
126
127 defcfg = '''
128 [DEFAULT]
129 max_batch_down = 65536
130 max_queue_time = 10
131 target_requests_outstanding = 3
132 http_timeout = 30
133 http_timeout_grace = 5
134 max_requests_outstanding = 6
135 max_batch_up = 4000
136 http_retry = 5
137 port = 80
138 vroutes = ''
139
140 #[server] or [<client>] overrides
141 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
142
143 # relating to virtual network
144 mtu = 1500
145
146 [SERVER]
147 server = SERVER
148 # addrs = 127.0.0.1 ::1
149 # url
150
151 # relating to virtual network
152 vvnetwork = 172.24.230.192
153 # vnetwork = <prefix>/<len>
154 # vadd  r  = <ipaddr>
155 # vrelay   = <ipaddr>
156
157
158 # [<client-ip4-or-ipv6-address>]
159 # password = <password>    # used by both, must match
160
161 [LIMIT]
162 max_batch_down = 262144
163 max_queue_time = 121
164 http_timeout = 121
165 target_requests_outstanding = 10
166 '''
167
168 # these need to be defined here so that they can be imported by import *
169 cfg = ConfigParser(strict=False)
170 optparser = OptionParser()
171
172 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
173 def mime_translate(s):
174   # SLIP-encoded packets cannot contain ESC ESC.
175   # Swap `-' and ESC.  The result cannot contain `--'
176   return s.translate(_mimetrans)
177
178 class ConfigResults:
179   def __init__(self):
180     pass
181   def __repr__(self):
182     return 'ConfigResults('+repr(self.__dict__)+')'
183
184 def log_discard(packet, iface, saddr, daddr, why):
185   log_debug(DBG.DROP,
186             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
187             d=packet)
188
189 #---------- packet parsing ----------
190
191 def packet_addrs(packet):
192   version = packet[0] >> 4
193   if version == 4:
194     addrlen = 4
195     saddroff = 3*4
196     factory = ipaddress.IPv4Address
197   elif version == 6:
198     addrlen = 16
199     saddroff = 2*4
200     factory = ipaddress.IPv6Address
201   else:
202     raise ValueError('unsupported IP version %d' % version)
203   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
204   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
205   return (saddr, daddr)
206
207 #---------- address handling ----------
208
209 def ipaddr(input):
210   try:
211     r = ipaddress.IPv4Address(input)
212   except AddressValueError:
213     r = ipaddress.IPv6Address(input)
214   return r
215
216 def ipnetwork(input):
217   try:
218     r = ipaddress.IPv4Network(input)
219   except NetworkValueError:
220     r = ipaddress.IPv6Network(input)
221   return r
222
223 #---------- ipif (SLIP) subprocess ----------
224
225 class SlipStreamDecoder():
226   def __init__(self, desc, on_packet):
227     self._buffer = b''
228     self._on_packet = on_packet
229     self._desc = desc
230     self._log('__init__')
231
232   def _log(self, msg, **kwargs):
233     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
234
235   def inputdata(self, data):
236     self._log('inputdata', d=data)
237     data = self._buffer + data
238     self._buffer = b''
239     packets = slip.decode(data, True)
240     self._buffer = packets.pop()
241     for packet in packets:
242       self._maybe_packet(packet)
243     self._log('bufremain', d=self._buffer)
244
245   def _maybe_packet(self, packet):
246     self._log('maybepacket', d=packet)
247     if len(packet):
248       self._on_packet(packet)
249
250   def flush(self):
251     self._log('flush')
252     data = self._buffer
253     self._buffer = b''
254     packets = slip.decode(data)
255     assert(len(packets) == 1)
256     self._maybe_packet(packets[0])
257
258 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
259   def __init__(self, router):
260     self._router = router
261     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
262   def connectionMade(self): pass
263   def outReceived(self, data):
264     self._decoder.inputdata(data)
265   def slip_on_packet(self, packet):
266     (saddr, daddr) = packet_addrs(packet)
267     if saddr.is_link_local or daddr.is_link_local:
268       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
269       return
270     self._router(packet, saddr, daddr)
271   def processEnded(self, status):
272     status.raiseException()
273
274 def start_ipif(command, router):
275   ipif = _IpifProcessProtocol(router)
276   reactor.spawnProcess(ipif,
277                        '/bin/sh',['sh','-xc', command],
278                        childFDs={0:'w', 1:'r', 2:2},
279                        env=None)
280   return ipif
281
282 def queue_inbound(ipif, packet):
283   log_debug(DBG.FLOW, "queue_inbound", d=packet)
284   ipif.transport.write(slip.delimiter)
285   ipif.transport.write(slip.encode(packet))
286   ipif.transport.write(slip.delimiter)
287
288 #---------- packet queue ----------
289
290 class PacketQueue():
291   def __init__(self, desc, max_queue_time):
292     self._desc = desc
293     assert(desc + '')
294     self._max_queue_time = max_queue_time
295     self._pq = collections.deque() # packets
296
297   def _log(self, dflag, msg, **kwargs):
298     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
299
300   def append(self, packet):
301     self._log(DBG.QUEUE, 'append', d=packet)
302     self._pq.append((time.monotonic(), packet))
303
304   def nonempty(self):
305     self._log(DBG.QUEUE, 'nonempty ?')
306     while True:
307       try: (queuetime, packet) = self._pq[0]
308       except IndexError:
309         self._log(DBG.QUEUE, 'nonempty ? empty.')
310         return False
311
312       age = time.monotonic() - queuetime
313       if age > self._max_queue_time:
314         # strip old packets off the front
315         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
316         self._pq.popleft()
317         continue
318
319       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
320       return True
321
322   def process(self, sizequery, moredata, max_batch):
323     # sizequery() should return size of batch so far
324     # moredata(s) should add s to batch
325     self._log(DBG.QUEUE, 'process...')
326     while True:
327       try: (dummy, packet) = self._pq[0]
328       except IndexError:
329         self._log(DBG.QUEUE, 'process... empty')
330         break
331
332       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
333
334       encoded = slip.encode(packet)
335       sofar = sizequery()  
336
337       self._log(DBG.QUEUE_CTRL,
338                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
339                 d=encoded)
340
341       if sofar > 0:
342         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
343           self._log(DBG.QUEUE_CTRL, 'process... overflow')
344           break
345         moredata(slip.delimiter)
346
347       moredata(encoded)
348       self._pq.popleft()
349
350 #---------- error handling ----------
351
352 _crashing = False
353
354 def crash(err):
355   global _crashing
356   _crashing = True
357   print('========== CRASH ==========', err,
358         '===========================', file=sys.stderr)
359   try: reactor.stop()
360   except twisted.internet.error.ReactorNotRunning: pass
361
362 def crash_on_defer(defer):
363   defer.addErrback(lambda err: crash(err))
364
365 def crash_on_critical(event):
366   if event.get('log_level') >= LogLevel.critical:
367     crash(twisted.logger.formatEvent(event))
368
369 #---------- config processing ----------
370
371 def _cfg_process_putatives():
372   servers = { }
373   clients = { }
374   # maps from abstract object to canonical name for cs's
375
376   def putative(cmap, abstract, canoncs):
377     try:
378       current_canoncs = cmap[abstract]
379     except KeyError:
380       pass
381     else:
382       assert(current_canoncs == canoncs)
383     cmap[abstract] = canoncs
384
385   server_pat = r'[-.0-9A-Za-z]+'
386   client_pat = r'[.:0-9a-f]+'
387   server_re = regexp.compile(server_pat)
388   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
389
390   for cs in cfg.sections():
391     if cs == 'LIMIT':
392       # plan A "[LIMIT]"
393       continue
394
395     try:
396       # plan B "[<client>]" part 1
397       ci = ipaddr(cs)
398     except AddressValueError:
399
400       if server_re.fullmatch(cs):
401         # plan C "[<servername>]"
402         putative(servers, cs, cs)
403         continue
404
405       if serverclient_re.fullmatch(cs):
406         # plan D "[<servername> <client>]" part 1
407         (pss,pcs) = cs.split(' ')
408
409         if pcs == 'LIMIT':
410           # plan E "[<servername> LIMIT]"
411           continue
412
413         try:
414           # plan D "[<servername> <client>]" part 2
415           ci = ipaddr(pc)
416         except AddressValueError:
417           # plan F "[<some thing we do not understand>]"
418           # well, we ignore this
419           print('warning: ignoring config section %s' % cs, file=sys.stderr)
420           continue
421
422         else: # no AddressValueError
423           # plan D "[<servername> <client]" part 3
424           putative(clients, ci, pcs)
425           putative(servers, pss, pss)
426           continue
427
428     else: # no AddressValueError
429       # plan B "[<client>" part 2
430       putative(clients, ci, cs)
431       continue
432
433   return (servers, clients)
434
435 def cfg_process_common(c, ss):
436   c.mtu = cfg.getint(ss, 'mtu')
437
438 def cfg_process_saddrs(c, ss):
439   class ServerAddr():
440     def __init__(self, port, addrspec):
441       self.port = port
442       # also self.addr
443       try:
444         self.addr = ipaddress.IPv4Address(addrspec)
445         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
446         self._inurl = b'%s'
447       except AddressValueError:
448         self.addr = ipaddress.IPv6Address(addrspec)
449         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
450         self._inurl = b'[%s]'
451     def make_endpoint(self):
452       return self._endpointfactory(reactor, self.port, self.addr)
453     def url(self):
454       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
455       if self.port != 80: url += b':%d' % self.port
456       url += b'/'
457       return url
458
459   c.port = cfg.getint(ss,'port')
460   c.saddrs = [ ]
461   for addrspec in cfg.get(ss, 'addrs').split():
462     sa = ServerAddr(c.port, addrspec)
463     c.saddrs.append(sa)
464
465 def cfg_process_vnetwork(c, ss):
466   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
467   if c.vnetwork.num_addresses < 3 + 2:
468     raise ValueError('vnetwork needs at least 2^3 addresses')
469
470 def cfg_process_vaddr(c, ss):
471   try:
472     c.vaddr = cfg.get(ss,'vaddr')
473   except NoOptionError:
474     cfg_process_vnetwork(c, ss)
475     c.vaddr = next(c.vnetwork.hosts())
476
477 def cfg_search_section(key,sections):
478   for section in sections:
479     if cfg.has_option(section, key):
480       return section
481   raise NoOptionError(key, repr(sections))
482
483 def cfg_search(getter,key,sections):
484   section = cfg_search_section(key,sections)
485   return getter(section, key)
486
487 def cfg_process_client_limited(cc,ss,sections,key):
488   val = cfg_search(cfg.getint, key, sections)
489   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
490   cc.__dict__[key] = min(val,lim)
491
492 def cfg_process_client_common(cc,ss,cs,ci):
493   # returns sections to search in, iff password is defined, otherwise None
494   cc.ci = ci
495
496   sections = ['%s %s' % (ss,cs),
497               cs,
498               ss,
499               'DEFAULT']
500
501   try: pwsection = cfg_search_section('password', sections)
502   except NoOptionError: return None
503     
504   pw = cfg.get(pwsection, 'password')
505   cc.password = pw.encode('utf-8')
506
507   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
508   cfg_process_client_limited(cc,ss,sections,'http_timeout')
509
510   return sections
511
512 def cfg_process_ipif(c, sections, varmap):
513   for d, s in varmap:
514     try: v = getattr(c, s)
515     except AttributeError: continue
516     setattr(c, d, v)
517
518   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
519
520   section = cfg_search_section('ipif', sections)
521   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
522
523 #---------- startup ----------
524
525 def common_startup(process_cfg):
526   # calls process_cfg(putative_clients, putative_servers)
527
528   # ConfigParser hates #-comments after values
529   trailingcomments_re = regexp.compile(r'#.*')
530   cfg.read_string(trailingcomments_re.sub('', defcfg))
531   need_defcfg = True
532
533   def readconfig(pathname, mandatory=True):
534     def log(m, p=pathname):
535       if not DBG.CONFIG in debug_set: return
536       print('DBG.CONFIG: %s: %s' % (m, pathname))
537
538     try:
539       files = os.listdir(pathname)
540
541     except FileNotFoundError:
542       if mandatory: raise
543       log('skipped')
544       return
545
546     except NotADirectoryError:
547       cfg.read(pathname)
548       log('read file')
549       return
550
551     # is a directory
552     log('directory')
553     re = regexp.compile('[^-A-Za-z0-9_]')
554     for f in os.listdir(pathname):
555       if re.search(f): continue
556       subpath = pathname + '/' + f
557       try:
558         os.stat(subpath)
559       except FileNotFoundError:
560         log('entry skipped', subpath)
561         continue
562       cfg.read(subpath)
563       log('entry read', subpath)
564       
565   def oc_config(od,os, value, op):
566     nonlocal need_defcfg
567     need_defcfg = False
568     readconfig(value)
569
570   def oc_extra_config(od,os, value, op):
571     readconfig(value)
572
573   def read_defconfig():
574     readconfig('/etc/hippotat/config.d', False)
575     readconfig('/etc/hippotat/passwords.d', False)
576     readconfig('/etc/hippotat/master.cfg',   False)
577
578   def oc_defconfig(od,os, value, op):
579     nonlocal need_defcfg
580     need_defcfg = False
581     read_defconfig(value)
582
583   def dfs_less_detailed(dl):
584     return [df for df in DBG.iterconstants() if df <= dl]
585
586   def ds_default(od,os,dl,op):
587     global debug_set
588     debug_set.clear
589     debug_set |= set(dfs_less_detailed(debug_def_detail))
590
591   def ds_select(od,os, spec, op):
592     for it in spec.split(','):
593
594       if it.startswith('-'):
595         mutator = debug_set.discard
596         it = it[1:]
597       else:
598         mutator = debug_set.add
599
600       if it == '+':
601         dfs = DBG.iterconstants()
602
603       else:
604         if it.endswith('+'):
605           mapper = dfs_less_detailed
606           it = it[0:len(it)-1]
607         else:
608           mapper = lambda x: [x]
609
610           try:
611             dfspec = DBG.lookupByName(it)
612           except ValueError:
613             optparser.error('unknown debug flag %s in --debug-select' % it)
614
615         dfs = mapper(dfspec)
616
617       for df in dfs:
618         mutator(df)
619
620   optparser.add_option('-D', '--debug',
621                        nargs=0,
622                        action='callback',
623                        help='enable default debug (to stdout)',
624                        callback= ds_default)
625
626   optparser.add_option('--debug-select',
627                        nargs=1,
628                        type='string',
629                        metavar='[-]DFLAG[+]|[-]+,...',
630                        help=
631 '''enable (`-': disable) each specified DFLAG;
632 `+': do same for all "more interesting" DFLAGSs;
633 just `+': all DFLAGs.
634   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
635                        action='callback',
636                        callback= ds_select)
637
638   optparser.add_option('-c', '--config',
639                        nargs=1,
640                        type='string',
641                        metavar='CONFIGFILE',
642                        dest='configfile',
643                        action='callback',
644                        callback= oc_config)
645
646   optparser.add_option('--extra-config',
647                        nargs=1,
648                        type='string',
649                        metavar='CONFIGFILE',
650                        dest='configfile',
651                        action='callback',
652                        callback= oc_extra_config)
653
654   optparser.add_option('--default-config',
655                        action='callback',
656                        callback= oc_defconfig)
657
658   (opts, args) = optparser.parse_args()
659   if len(args): optparser.error('no non-option arguments please')
660
661   if need_defcfg:
662     read_defconfig()
663
664   try:
665     (pss, pcs) = _cfg_process_putatives()
666     process_cfg(opts, pss, pcs)
667   except (configparser.Error, ValueError):
668     traceback.print_exc(file=sys.stderr)
669     print('\nInvalid configuration, giving up.', file=sys.stderr)
670     sys.exit(12)
671
672
673   #print('X', debug_set, file=sys.stderr)
674
675   log_formatter = twisted.logger.formatEventAsClassicLogText
676   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
677   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
678   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
679   stdsomething_obs = twisted.logger.FilteringLogObserver(
680     stderr_obs, [pred], stdout_obs
681   )
682   global file_log_observer
683   file_log_observer = twisted.logger.FilteringLogObserver(
684     stdsomething_obs, [LogNotBoringTwisted()]
685   )
686   #log_observer = stdsomething_obs
687   twisted.logger.globalLogBeginner.beginLoggingTo(
688     [ file_log_observer, crash_on_critical ]
689     )
690
691 def common_run():
692   log_debug(DBG.INIT, 'entering reactor')
693   if not _crashing: reactor.run()
694   print('ENDED', file=sys.stderr)