chiark / gitweb /
config dir reading fix
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2 #
3 # Hippotat - Asinine IP Over HTTP program
4 # hippotatlib/__init__.py - common library code
5 #
6 # Copyright 2017 Ian Jackson
7 #
8 # GPLv3+
9 #
10 #    This program is free software: you can redistribute it and/or modify
11 #    it under the terms of the GNU General Public License as published by
12 #    the Free Software Foundation, either version 3 of the License, or
13 #    (at your option) any later version.
14 #
15 #    This program is distributed in the hope that it will be useful,
16 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
17 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 #    GNU General Public License for more details.
19 #
20 #    You should have received a copy of the GNU General Public License
21 #    along with this program, in the file GPLv3.  If not,
22 #    see <http://www.gnu.org/licenses/>.
23
24
25 import signal
26 signal.signal(signal.SIGINT, signal.SIG_DFL)
27
28 import sys
29 import os
30
31 from zope.interface import implementer
32
33 import twisted
34 from twisted.internet import reactor
35 import twisted.internet.endpoints
36 import twisted.logger
37 from twisted.logger import LogLevel
38 import twisted.python.constants
39 from twisted.python.constants import NamedConstant
40
41 import ipaddress
42 from ipaddress import AddressValueError
43
44 from optparse import OptionParser
45 import configparser
46 from configparser import ConfigParser
47 from configparser import NoOptionError
48
49 from functools import partial
50
51 import collections
52 import time
53 import codecs
54 import traceback
55
56 import re as regexp
57
58 import hippotatlib.slip as slip
59
60 class DBG(twisted.python.constants.Names):
61   INIT = NamedConstant()
62   CONFIG = NamedConstant()
63   ROUTE = NamedConstant()
64   DROP = NamedConstant()
65   OWNSOURCE = NamedConstant()
66   FLOW = NamedConstant()
67   HTTP = NamedConstant()
68   TWISTED = NamedConstant()
69   QUEUE = NamedConstant()
70   HTTP_CTRL = NamedConstant()
71   QUEUE_CTRL = NamedConstant()
72   HTTP_FULL = NamedConstant()
73   CTRL_DUMP = NamedConstant()
74   SLIP_FULL = NamedConstant()
75   DATA_COMPLETE = NamedConstant()
76
77 _hex_codec = codecs.getencoder('hex_codec')
78
79 #---------- logging ----------
80
81 org_stderr = sys.stderr
82
83 log = twisted.logger.Logger()
84
85 debug_set = set()
86 debug_def_detail = DBG.HTTP
87
88 def log_debug(dflag, msg, idof=None, d=None):
89   if dflag not in debug_set: return
90   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
91   if idof is not None:
92     msg = '[%#x] %s' % (id(idof), msg)
93   if d is not None:
94     trunc = ''
95     if not DBG.DATA_COMPLETE in debug_set:
96       if len(d) > 64:
97         d = d[0:64]
98         trunc = '...'
99     d = _hex_codec(d)[0].decode('ascii')
100     msg += ' ' + d + trunc
101   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
102
103 @implementer(twisted.logger.ILogFilterPredicate)
104 class LogNotBoringTwisted:
105   def __call__(self, event):
106     yes = twisted.logger.PredicateResult.yes
107     no  = twisted.logger.PredicateResult.no
108     try:
109       if event.get('log_level') != LogLevel.info:
110         return yes
111       dflag = event.get('dflag')
112       if dflag is False                            : return yes
113       if dflag                         in debug_set: return yes
114       if dflag is None and DBG.TWISTED in debug_set: return yes
115       return no
116     except Exception:
117       print(traceback.format_exc(), file=org_stderr)
118       return yes
119
120 #---------- default config ----------
121
122 defcfg = '''
123 [DEFAULT]
124 max_batch_down = 65536
125 max_queue_time = 10
126 target_requests_outstanding = 3
127 http_timeout = 30
128 http_timeout_grace = 5
129 max_requests_outstanding = 6
130 max_batch_up = 4000
131 http_retry = 5
132 port = 80
133 vroutes = ''
134
135 #[server] or [<client>] overrides
136 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
137
138 # relating to virtual network
139 mtu = 1500
140
141 [SERVER]
142 server = SERVER
143 # addrs = 127.0.0.1 ::1
144 # url
145
146 # relating to virtual network
147 vvnetwork = 172.24.230.192
148 # vnetwork = <prefix>/<len>
149 # vadd  r  = <ipaddr>
150 # vrelay   = <ipaddr>
151
152
153 # [<client-ip4-or-ipv6-address>]
154 # password = <password>    # used by both, must match
155
156 [LIMIT]
157 max_batch_down = 262144
158 max_queue_time = 121
159 http_timeout = 121
160 target_requests_outstanding = 10
161 '''
162
163 # these need to be defined here so that they can be imported by import *
164 cfg = ConfigParser(strict=False)
165 optparser = OptionParser()
166
167 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
168 def mime_translate(s):
169   # SLIP-encoded packets cannot contain ESC ESC.
170   # Swap `-' and ESC.  The result cannot contain `--'
171   return s.translate(_mimetrans)
172
173 class ConfigResults:
174   def __init__(self):
175     pass
176   def __repr__(self):
177     return 'ConfigResults('+repr(self.__dict__)+')'
178
179 def log_discard(packet, iface, saddr, daddr, why):
180   log_debug(DBG.DROP,
181             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
182             d=packet)
183
184 #---------- packet parsing ----------
185
186 def packet_addrs(packet):
187   version = packet[0] >> 4
188   if version == 4:
189     addrlen = 4
190     saddroff = 3*4
191     factory = ipaddress.IPv4Address
192   elif version == 6:
193     addrlen = 16
194     saddroff = 2*4
195     factory = ipaddress.IPv6Address
196   else:
197     raise ValueError('unsupported IP version %d' % version)
198   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
199   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
200   return (saddr, daddr)
201
202 #---------- address handling ----------
203
204 def ipaddr(input):
205   try:
206     r = ipaddress.IPv4Address(input)
207   except AddressValueError:
208     r = ipaddress.IPv6Address(input)
209   return r
210
211 def ipnetwork(input):
212   try:
213     r = ipaddress.IPv4Network(input)
214   except NetworkValueError:
215     r = ipaddress.IPv6Network(input)
216   return r
217
218 #---------- ipif (SLIP) subprocess ----------
219
220 class SlipStreamDecoder():
221   def __init__(self, desc, on_packet):
222     self._buffer = b''
223     self._on_packet = on_packet
224     self._desc = desc
225     self._log('__init__')
226
227   def _log(self, msg, **kwargs):
228     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
229
230   def inputdata(self, data):
231     self._log('inputdata', d=data)
232     data = self._buffer + data
233     self._buffer = b''
234     packets = slip.decode(data, True)
235     self._buffer = packets.pop()
236     for packet in packets:
237       self._maybe_packet(packet)
238     self._log('bufremain', d=self._buffer)
239
240   def _maybe_packet(self, packet):
241     self._log('maybepacket', d=packet)
242     if len(packet):
243       self._on_packet(packet)
244
245   def flush(self):
246     self._log('flush')
247     data = self._buffer
248     self._buffer = b''
249     packets = slip.decode(data)
250     assert(len(packets) == 1)
251     self._maybe_packet(packets[0])
252
253 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
254   def __init__(self, router):
255     self._router = router
256     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
257   def connectionMade(self): pass
258   def outReceived(self, data):
259     self._decoder.inputdata(data)
260   def slip_on_packet(self, packet):
261     (saddr, daddr) = packet_addrs(packet)
262     if saddr.is_link_local or daddr.is_link_local:
263       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
264       return
265     self._router(packet, saddr, daddr)
266   def processEnded(self, status):
267     status.raiseException()
268
269 def start_ipif(command, router):
270   ipif = _IpifProcessProtocol(router)
271   reactor.spawnProcess(ipif,
272                        '/bin/sh',['sh','-xc', command],
273                        childFDs={0:'w', 1:'r', 2:2},
274                        env=None)
275   return ipif
276
277 def queue_inbound(ipif, packet):
278   log_debug(DBG.FLOW, "queue_inbound", d=packet)
279   ipif.transport.write(slip.delimiter)
280   ipif.transport.write(slip.encode(packet))
281   ipif.transport.write(slip.delimiter)
282
283 #---------- packet queue ----------
284
285 class PacketQueue():
286   def __init__(self, desc, max_queue_time):
287     self._desc = desc
288     assert(desc + '')
289     self._max_queue_time = max_queue_time
290     self._pq = collections.deque() # packets
291
292   def _log(self, dflag, msg, **kwargs):
293     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
294
295   def append(self, packet):
296     self._log(DBG.QUEUE, 'append', d=packet)
297     self._pq.append((time.monotonic(), packet))
298
299   def nonempty(self):
300     self._log(DBG.QUEUE, 'nonempty ?')
301     while True:
302       try: (queuetime, packet) = self._pq[0]
303       except IndexError:
304         self._log(DBG.QUEUE, 'nonempty ? empty.')
305         return False
306
307       age = time.monotonic() - queuetime
308       if age > self._max_queue_time:
309         # strip old packets off the front
310         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
311         self._pq.popleft()
312         continue
313
314       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
315       return True
316
317   def process(self, sizequery, moredata, max_batch):
318     # sizequery() should return size of batch so far
319     # moredata(s) should add s to batch
320     self._log(DBG.QUEUE, 'process...')
321     while True:
322       try: (dummy, packet) = self._pq[0]
323       except IndexError:
324         self._log(DBG.QUEUE, 'process... empty')
325         break
326
327       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
328
329       encoded = slip.encode(packet)
330       sofar = sizequery()  
331
332       self._log(DBG.QUEUE_CTRL,
333                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
334                 d=encoded)
335
336       if sofar > 0:
337         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
338           self._log(DBG.QUEUE_CTRL, 'process... overflow')
339           break
340         moredata(slip.delimiter)
341
342       moredata(encoded)
343       self._pq.popleft()
344
345 #---------- error handling ----------
346
347 _crashing = False
348
349 def crash(err):
350   global _crashing
351   _crashing = True
352   print('========== CRASH ==========', err,
353         '===========================', file=sys.stderr)
354   try: reactor.stop()
355   except twisted.internet.error.ReactorNotRunning: pass
356
357 def crash_on_defer(defer):
358   defer.addErrback(lambda err: crash(err))
359
360 def crash_on_critical(event):
361   if event.get('log_level') >= LogLevel.critical:
362     crash(twisted.logger.formatEvent(event))
363
364 #---------- config processing ----------
365
366 def _cfg_process_putatives():
367   servers = { }
368   clients = { }
369   # maps from abstract object to canonical name for cs's
370
371   def putative(cmap, abstract, canoncs):
372     try:
373       current_canoncs = cmap[abstract]
374     except KeyError:
375       pass
376     else:
377       assert(current_canoncs == canoncs)
378     cmap[abstract] = canoncs
379
380   server_pat = r'[-.0-9A-Za-z]+'
381   client_pat = r'[.:0-9a-f]+'
382   server_re = regexp.compile(server_pat)
383   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
384
385   for cs in cfg.sections():
386     if cs == 'LIMIT':
387       # plan A "[LIMIT]"
388       continue
389
390     try:
391       # plan B "[<client>]" part 1
392       ci = ipaddr(cs)
393     except AddressValueError:
394
395       if server_re.fullmatch(cs):
396         # plan C "[<servername>]"
397         putative(servers, cs, cs)
398         continue
399
400       if serverclient_re.fullmatch(cs):
401         # plan D "[<servername> <client>]" part 1
402         (pss,pcs) = cs.split(' ')
403
404         if pcs == 'LIMIT':
405           # plan E "[<servername> LIMIT]"
406           continue
407
408         try:
409           # plan D "[<servername> <client>]" part 2
410           ci = ipaddr(pc)
411         except AddressValueError:
412           # plan F "[<some thing we do not understand>]"
413           # well, we ignore this
414           print('warning: ignoring config section %s' % cs, file=sys.stderr)
415           continue
416
417         else: # no AddressValueError
418           # plan D "[<servername> <client]" part 3
419           putative(clients, ci, pcs)
420           putative(servers, pss, pss)
421           continue
422
423     else: # no AddressValueError
424       # plan B "[<client>" part 2
425       putative(clients, ci, cs)
426       continue
427
428   return (servers, clients)
429
430 def cfg_process_common(c, ss):
431   c.mtu = cfg.getint(ss, 'mtu')
432
433 def cfg_process_saddrs(c, ss):
434   class ServerAddr():
435     def __init__(self, port, addrspec):
436       self.port = port
437       # also self.addr
438       try:
439         self.addr = ipaddress.IPv4Address(addrspec)
440         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
441         self._inurl = b'%s'
442       except AddressValueError:
443         self.addr = ipaddress.IPv6Address(addrspec)
444         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
445         self._inurl = b'[%s]'
446     def make_endpoint(self):
447       return self._endpointfactory(reactor, self.port, self.addr)
448     def url(self):
449       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
450       if self.port != 80: url += b':%d' % self.port
451       url += b'/'
452       return url
453
454   c.port = cfg.getint(ss,'port')
455   c.saddrs = [ ]
456   for addrspec in cfg.get(ss, 'addrs').split():
457     sa = ServerAddr(c.port, addrspec)
458     c.saddrs.append(sa)
459
460 def cfg_process_vnetwork(c, ss):
461   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
462   if c.vnetwork.num_addresses < 3 + 2:
463     raise ValueError('vnetwork needs at least 2^3 addresses')
464
465 def cfg_process_vaddr(c, ss):
466   try:
467     c.vaddr = cfg.get(ss,'vaddr')
468   except NoOptionError:
469     cfg_process_vnetwork(c, ss)
470     c.vaddr = next(c.vnetwork.hosts())
471
472 def cfg_search_section(key,sections):
473   for section in sections:
474     if cfg.has_option(section, key):
475       return section
476   raise NoOptionError(key, repr(sections))
477
478 def cfg_search(getter,key,sections):
479   section = cfg_search_section(key,sections)
480   return getter(section, key)
481
482 def cfg_process_client_limited(cc,ss,sections,key):
483   val = cfg_search(cfg.getint, key, sections)
484   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
485   cc.__dict__[key] = min(val,lim)
486
487 def cfg_process_client_common(cc,ss,cs,ci):
488   # returns sections to search in, iff password is defined, otherwise None
489   cc.ci = ci
490
491   sections = ['%s %s' % (ss,cs),
492               cs,
493               ss,
494               'DEFAULT']
495
496   try: pwsection = cfg_search_section('password', sections)
497   except NoOptionError: return None
498     
499   pw = cfg.get(pwsection, 'password')
500   cc.password = pw.encode('utf-8')
501
502   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
503   cfg_process_client_limited(cc,ss,sections,'http_timeout')
504
505   return sections
506
507 def cfg_process_ipif(c, sections, varmap):
508   for d, s in varmap:
509     try: v = getattr(c, s)
510     except AttributeError: continue
511     setattr(c, d, v)
512
513   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
514
515   section = cfg_search_section('ipif', sections)
516   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
517
518 #---------- startup ----------
519
520 def common_startup(process_cfg):
521   # calls process_cfg(putative_clients, putative_servers)
522
523   # ConfigParser hates #-comments after values
524   trailingcomments_re = regexp.compile(r'#.*')
525   cfg.read_string(trailingcomments_re.sub('', defcfg))
526   need_defcfg = True
527
528   def readconfig(pathname, mandatory=True):
529     def log(m, p=pathname):
530       if not DBG.CONFIG in debug_set: return
531       print('DBG.CONFIG: %s: %s' % (m, pathname))
532
533     try:
534       files = os.listdir(pathname)
535
536     except FileNotFoundError:
537       if mandatory: raise
538       log('skipped')
539       return
540
541     except NotADirectoryError:
542       cfg.read(pathname)
543       log('read file')
544       return
545
546     # is a directory
547     log('directory')
548     re = regexp.compile('[^-A-Za-z0-9_]')
549     for f in os.listdir(pathname):
550       if re.search(f): continue
551       subpath = pathname + '/' + f
552       try:
553         os.stat(subpath)
554       except FileNotFoundError:
555         log('entry skipped', subpath)
556         continue
557       cfg.read(subpath)
558       log('entry read', subpath)
559       
560   def oc_config(od,os, value, op):
561     nonlocal need_defcfg
562     need_defcfg = False
563     readconfig(value)
564
565   def oc_extra_config(od,os, value, op):
566     readconfig(value)
567
568   def read_defconfig():
569     readconfig('/etc/hippotat/config.d', False)
570     readconfig('/etc/hippotat/passwords.d', False)
571     readconfig('/etc/hippotat/master.cfg',   False)
572
573   def oc_defconfig(od,os, value, op):
574     nonlocal need_defcfg
575     need_defcfg = False
576     read_defconfig(value)
577
578   def dfs_less_detailed(dl):
579     return [df for df in DBG.iterconstants() if df <= dl]
580
581   def ds_default(od,os,dl,op):
582     global debug_set
583     debug_set.clear
584     debug_set |= set(dfs_less_detailed(debug_def_detail))
585
586   def ds_select(od,os, spec, op):
587     for it in spec.split(','):
588
589       if it.startswith('-'):
590         mutator = debug_set.discard
591         it = it[1:]
592       else:
593         mutator = debug_set.add
594
595       if it == '+':
596         dfs = DBG.iterconstants()
597
598       else:
599         if it.endswith('+'):
600           mapper = dfs_less_detailed
601           it = it[0:len(it)-1]
602         else:
603           mapper = lambda x: [x]
604
605           try:
606             dfspec = DBG.lookupByName(it)
607           except ValueError:
608             optparser.error('unknown debug flag %s in --debug-select' % it)
609
610         dfs = mapper(dfspec)
611
612       for df in dfs:
613         mutator(df)
614
615   optparser.add_option('-D', '--debug',
616                        nargs=0,
617                        action='callback',
618                        help='enable default debug (to stdout)',
619                        callback= ds_default)
620
621   optparser.add_option('--debug-select',
622                        nargs=1,
623                        type='string',
624                        metavar='[-]DFLAG[+]|[-]+,...',
625                        help=
626 '''enable (`-': disable) each specified DFLAG;
627 `+': do same for all "more interesting" DFLAGSs;
628 just `+': all DFLAGs.
629   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
630                        action='callback',
631                        callback= ds_select)
632
633   optparser.add_option('-c', '--config',
634                        nargs=1,
635                        type='string',
636                        metavar='CONFIGFILE',
637                        dest='configfile',
638                        action='callback',
639                        callback= oc_config)
640
641   optparser.add_option('--extra-config',
642                        nargs=1,
643                        type='string',
644                        metavar='CONFIGFILE',
645                        dest='configfile',
646                        action='callback',
647                        callback= oc_extra_config)
648
649   optparser.add_option('--default-config',
650                        action='callback',
651                        callback= oc_defconfig)
652
653   (opts, args) = optparser.parse_args()
654   if len(args): optparser.error('no non-option arguments please')
655
656   if need_defcfg:
657     read_defconfig()
658
659   try:
660     (pss, pcs) = _cfg_process_putatives()
661     process_cfg(opts, pss, pcs)
662   except (configparser.Error, ValueError):
663     traceback.print_exc(file=sys.stderr)
664     print('\nInvalid configuration, giving up.', file=sys.stderr)
665     sys.exit(12)
666
667
668   #print('X', debug_set, file=sys.stderr)
669
670   log_formatter = twisted.logger.formatEventAsClassicLogText
671   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
672   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
673   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
674   stdsomething_obs = twisted.logger.FilteringLogObserver(
675     stderr_obs, [pred], stdout_obs
676   )
677   global file_log_observer
678   file_log_observer = twisted.logger.FilteringLogObserver(
679     stdsomething_obs, [LogNotBoringTwisted()]
680   )
681   #log_observer = stdsomething_obs
682   twisted.logger.globalLogBeginner.beginLoggingTo(
683     [ file_log_observer, crash_on_critical ]
684     )
685
686 def common_run():
687   log_debug(DBG.INIT, 'entering reactor')
688   if not _crashing: reactor.run()
689   print('ENDED', file=sys.stderr)