chiark / gitweb /
Revert "config: Check args after reading config"
[hippotat.git] / hippotatlib / __init__.py
1 # -*- python -*-
2 #
3 # Hippotat - Asinine IP Over HTTP program
4 # hippotatlib/__init__.py - common library code
5 #
6 # Copyright 2017 Ian Jackson
7 #
8 # GPLv3+
9 #
10 #    This program is free software: you can redistribute it and/or modify
11 #    it under the terms of the GNU General Public License as published by
12 #    the Free Software Foundation, either version 3 of the License, or
13 #    (at your option) any later version.
14 #
15 #    This program is distributed in the hope that it will be useful,
16 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
17 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18 #    GNU General Public License for more details.
19 #
20 #    You should have received a copy of the GNU General Public License
21 #    along with this program, in the file GPLv3.  If not,
22 #    see <http://www.gnu.org/licenses/>.
23
24
25 import signal
26 signal.signal(signal.SIGINT, signal.SIG_DFL)
27
28 import sys
29 import os
30
31 from zope.interface import implementer
32
33 import twisted
34 from twisted.internet import reactor
35 import twisted.internet.endpoints
36 import twisted.logger
37 from twisted.logger import LogLevel
38 import twisted.python.constants
39 from twisted.python.constants import NamedConstant
40
41 import ipaddress
42 from ipaddress import AddressValueError
43
44 from optparse import OptionParser
45 import configparser
46 from configparser import ConfigParser
47 from configparser import NoOptionError
48
49 from functools import partial
50
51 import collections
52 import time
53 import codecs
54 import traceback
55
56 import re as regexp
57
58 import hippotatlib.slip as slip
59
60 class DBG(twisted.python.constants.Names):
61   INIT = NamedConstant()
62   CONFIG = NamedConstant()
63   ROUTE = NamedConstant()
64   DROP = NamedConstant()
65   OWNSOURCE = NamedConstant()
66   FLOW = NamedConstant()
67   HTTP = NamedConstant()
68   TWISTED = NamedConstant()
69   QUEUE = NamedConstant()
70   HTTP_CTRL = NamedConstant()
71   QUEUE_CTRL = NamedConstant()
72   HTTP_FULL = NamedConstant()
73   CTRL_DUMP = NamedConstant()
74   SLIP_FULL = NamedConstant()
75   DATA_COMPLETE = NamedConstant()
76
77 _hex_codec = codecs.getencoder('hex_codec')
78
79 #---------- logging ----------
80
81 org_stderr = sys.stderr
82
83 log = twisted.logger.Logger()
84
85 debug_set = set()
86 debug_def_detail = DBG.HTTP
87
88 def log_debug(dflag, msg, idof=None, d=None):
89   if dflag not in debug_set: return
90   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
91   if idof is not None:
92     msg = '[%#x] %s' % (id(idof), msg)
93   if d is not None:
94     trunc = ''
95     if not DBG.DATA_COMPLETE in debug_set:
96       if len(d) > 64:
97         d = d[0:64]
98         trunc = '...'
99     d = _hex_codec(d)[0].decode('ascii')
100     msg += ' ' + d + trunc
101   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
102
103 @implementer(twisted.logger.ILogFilterPredicate)
104 class LogNotBoringTwisted:
105   def __call__(self, event):
106     yes = twisted.logger.PredicateResult.yes
107     no  = twisted.logger.PredicateResult.no
108     try:
109       if event.get('log_level') != LogLevel.info:
110         return yes
111       dflag = event.get('dflag')
112       if dflag is False                            : return yes
113       if dflag                         in debug_set: return yes
114       if dflag is None and DBG.TWISTED in debug_set: return yes
115       return no
116     except Exception:
117       print(traceback.format_exc(), file=org_stderr)
118       return yes
119
120 #---------- default config ----------
121
122 defcfg = '''
123 [DEFAULT]
124 max_batch_down = 65536
125 max_queue_time = 10
126 target_requests_outstanding = 3
127 http_timeout = 30
128 http_timeout_grace = 5
129 max_requests_outstanding = 6
130 max_batch_up = 4000
131 http_retry = 5
132 port = 80
133 vroutes = ''
134
135 #[server] or [<client>] overrides
136 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
137
138 # relating to virtual network
139 mtu = 1500
140
141 [SERVER]
142 server = SERVER
143 # addrs = 127.0.0.1 ::1
144 # url
145
146 # relating to virtual network
147 vvnetwork = 172.24.230.192
148 # vnetwork = <prefix>/<len>
149 # vadd  r  = <ipaddr>
150 # vrelay   = <ipaddr>
151
152
153 # [<client-ip4-or-ipv6-address>]
154 # password = <password>    # used by both, must match
155
156 [LIMIT]
157 max_batch_down = 262144
158 max_queue_time = 121
159 http_timeout = 121
160 target_requests_outstanding = 10
161 '''
162
163 # these need to be defined here so that they can be imported by import *
164 cfg = ConfigParser(strict=False)
165 optparser = OptionParser()
166
167 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
168 def mime_translate(s):
169   # SLIP-encoded packets cannot contain ESC ESC.
170   # Swap `-' and ESC.  The result cannot contain `--'
171   return s.translate(_mimetrans)
172
173 class ConfigResults:
174   def __init__(self):
175     pass
176   def __repr__(self):
177     return 'ConfigResults('+repr(self.__dict__)+')'
178
179 def log_discard(packet, iface, saddr, daddr, why):
180   log_debug(DBG.DROP,
181             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
182             d=packet)
183
184 #---------- packet parsing ----------
185
186 def packet_addrs(packet):
187   version = packet[0] >> 4
188   if version == 4:
189     addrlen = 4
190     saddroff = 3*4
191     factory = ipaddress.IPv4Address
192   elif version == 6:
193     addrlen = 16
194     saddroff = 2*4
195     factory = ipaddress.IPv6Address
196   else:
197     raise ValueError('unsupported IP version %d' % version)
198   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
199   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
200   return (saddr, daddr)
201
202 #---------- address handling ----------
203
204 def ipaddr(input):
205   try:
206     r = ipaddress.IPv4Address(input)
207   except AddressValueError:
208     r = ipaddress.IPv6Address(input)
209   return r
210
211 def ipnetwork(input):
212   try:
213     r = ipaddress.IPv4Network(input)
214   except NetworkValueError:
215     r = ipaddress.IPv6Network(input)
216   return r
217
218 #---------- ipif (SLIP) subprocess ----------
219
220 class SlipStreamDecoder():
221   def __init__(self, desc, on_packet):
222     self._buffer = b''
223     self._on_packet = on_packet
224     self._desc = desc
225     self._log('__init__')
226
227   def _log(self, msg, **kwargs):
228     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
229
230   def inputdata(self, data):
231     self._log('inputdata', d=data)
232     data = self._buffer + data
233     self._buffer = b''
234     packets = slip.decode(data, True)
235     self._buffer = packets.pop()
236     for packet in packets:
237       self._maybe_packet(packet)
238     self._log('bufremain', d=self._buffer)
239
240   def _maybe_packet(self, packet):
241     self._log('maybepacket', d=packet)
242     if len(packet):
243       self._on_packet(packet)
244
245   def flush(self):
246     self._log('flush')
247     data = self._buffer
248     self._buffer = b''
249     packets = slip.decode(data)
250     assert(len(packets) == 1)
251     self._maybe_packet(packets[0])
252
253 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
254   def __init__(self, router):
255     self._router = router
256     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
257   def connectionMade(self): pass
258   def outReceived(self, data):
259     self._decoder.inputdata(data)
260   def slip_on_packet(self, packet):
261     (saddr, daddr) = packet_addrs(packet)
262     if saddr.is_link_local or daddr.is_link_local:
263       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
264       return
265     self._router(packet, saddr, daddr)
266   def processEnded(self, status):
267     status.raiseException()
268
269 def start_ipif(command, router):
270   ipif = _IpifProcessProtocol(router)
271   reactor.spawnProcess(ipif,
272                        '/bin/sh',['sh','-xc', command],
273                        childFDs={0:'w', 1:'r', 2:2},
274                        env=None)
275   return ipif
276
277 def queue_inbound(ipif, packet):
278   log_debug(DBG.FLOW, "queue_inbound", d=packet)
279   ipif.transport.write(slip.delimiter)
280   ipif.transport.write(slip.encode(packet))
281   ipif.transport.write(slip.delimiter)
282
283 #---------- packet queue ----------
284
285 class PacketQueue():
286   def __init__(self, desc, max_queue_time):
287     self._desc = desc
288     assert(desc + '')
289     self._max_queue_time = max_queue_time
290     self._pq = collections.deque() # packets
291
292   def _log(self, dflag, msg, **kwargs):
293     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
294
295   def append(self, packet):
296     self._log(DBG.QUEUE, 'append', d=packet)
297     self._pq.append((time.monotonic(), packet))
298
299   def nonempty(self):
300     self._log(DBG.QUEUE, 'nonempty ?')
301     while True:
302       try: (queuetime, packet) = self._pq[0]
303       except IndexError:
304         self._log(DBG.QUEUE, 'nonempty ? empty.')
305         return False
306
307       age = time.monotonic() - queuetime
308       if age > self._max_queue_time:
309         # strip old packets off the front
310         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
311         self._pq.popleft()
312         continue
313
314       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
315       return True
316
317   def process(self, sizequery, moredata, max_batch):
318     # sizequery() should return size of batch so far
319     # moredata(s) should add s to batch
320     self._log(DBG.QUEUE, 'process...')
321     while True:
322       try: (dummy, packet) = self._pq[0]
323       except IndexError:
324         self._log(DBG.QUEUE, 'process... empty')
325         break
326
327       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
328
329       encoded = slip.encode(packet)
330       sofar = sizequery()  
331
332       self._log(DBG.QUEUE_CTRL,
333                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
334                 d=encoded)
335
336       if sofar > 0:
337         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
338           self._log(DBG.QUEUE_CTRL, 'process... overflow')
339           break
340         moredata(slip.delimiter)
341
342       moredata(encoded)
343       self._pq.popleft()
344
345 #---------- error handling ----------
346
347 _crashing = False
348
349 def crash(err):
350   global _crashing
351   _crashing = True
352   print('========== CRASH ==========', err,
353         '===========================', file=sys.stderr)
354   try: reactor.stop()
355   except twisted.internet.error.ReactorNotRunning: pass
356
357 def crash_on_defer(defer):
358   defer.addErrback(lambda err: crash(err))
359
360 def crash_on_critical(event):
361   if event.get('log_level') >= LogLevel.critical:
362     crash(twisted.logger.formatEvent(event))
363
364 #---------- config processing ----------
365
366 def _cfg_process_putatives():
367   servers = { }
368   clients = { }
369   # maps from abstract object to canonical name for cs's
370
371   def putative(cmap, abstract, canoncs):
372     try:
373       current_canoncs = cmap[abstract]
374     except KeyError:
375       pass
376     else:
377       assert(current_canoncs == canoncs)
378     cmap[abstract] = canoncs
379
380   server_pat = r'[-.0-9A-Za-z]+'
381   client_pat = r'[.:0-9a-f]+'
382   server_re = regexp.compile(server_pat)
383   serverclient_re = regexp.compile(server_pat + r' ' + client_pat)
384
385   for cs in cfg.sections():
386     if cs == 'LIMIT':
387       # plan A "[LIMIT]"
388       continue
389
390     try:
391       # plan B "[<client>]" part 1
392       ci = ipaddr(cs)
393     except AddressValueError:
394
395       if server_re.fullmatch(cs):
396         # plan C "[<servername>]"
397         putative(servers, cs, cs)
398         continue
399
400       if serverclient_re.fullmatch(cs):
401         # plan D "[<servername> <client>]" part 1
402         (pss,pcs) = cs.split(' ')
403
404         if pcs == 'LIMIT':
405           # plan E "[<servername> LIMIT]"
406           continue
407
408         try:
409           # plan D "[<servername> <client>]" part 2
410           ci = ipaddr(pc)
411         except AddressValueError:
412           # plan F "[<some thing we do not understand>]"
413           # well, we ignore this
414           print('warning: ignoring config section %s' % cs, file=sys.stderr)
415           continue
416
417         else: # no AddressValueError
418           # plan D "[<servername> <client]" part 3
419           putative(clients, ci, pcs)
420           putative(servers, pss, pss)
421           continue
422
423     else: # no AddressValueError
424       # plan B "[<client>" part 2
425       putative(clients, ci, cs)
426       continue
427
428   return (servers, clients)
429
430 def cfg_process_common(c, ss):
431   c.mtu = cfg.getint(ss, 'mtu')
432
433 def cfg_process_saddrs(c, ss):
434   class ServerAddr():
435     def __init__(self, port, addrspec):
436       self.port = port
437       # also self.addr
438       try:
439         self.addr = ipaddress.IPv4Address(addrspec)
440         self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
441         self._inurl = b'%s'
442       except AddressValueError:
443         self.addr = ipaddress.IPv6Address(addrspec)
444         self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
445         self._inurl = b'[%s]'
446     def make_endpoint(self):
447       return self._endpointfactory(reactor, self.port, self.addr)
448     def url(self):
449       url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
450       if self.port != 80: url += b':%d' % self.port
451       url += b'/'
452       return url
453
454   c.port = cfg.getint(ss,'port')
455   c.saddrs = [ ]
456   for addrspec in cfg.get(ss, 'addrs').split():
457     sa = ServerAddr(c.port, addrspec)
458     c.saddrs.append(sa)
459
460 def cfg_process_vnetwork(c, ss):
461   c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork'))
462   if c.vnetwork.num_addresses < 3 + 2:
463     raise ValueError('vnetwork needs at least 2^3 addresses')
464
465 def cfg_process_vaddr(c, ss):
466   try:
467     c.vaddr = cfg.get(ss,'vaddr')
468   except NoOptionError:
469     cfg_process_vnetwork(c, ss)
470     c.vaddr = next(c.vnetwork.hosts())
471
472 def cfg_search_section(key,sections):
473   for section in sections:
474     if cfg.has_option(section, key):
475       return section
476   raise NoOptionError(key, repr(sections))
477
478 def cfg_search(getter,key,sections):
479   section = cfg_search_section(key,sections)
480   return getter(section, key)
481
482 def cfg_process_client_limited(cc,ss,sections,key):
483   val = cfg_search(cfg.getint, key, sections)
484   lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT'])
485   cc.__dict__[key] = min(val,lim)
486
487 def cfg_process_client_common(cc,ss,cs,ci):
488   # returns sections to search in, iff password is defined, otherwise None
489   cc.ci = ci
490
491   sections = ['%s %s' % (ss,cs),
492               cs,
493               ss,
494               'DEFAULT']
495
496   try: pwsection = cfg_search_section('password', sections)
497   except NoOptionError: return None
498     
499   pw = cfg.get(pwsection, 'password')
500   cc.password = pw.encode('utf-8')
501
502   cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding')
503   cfg_process_client_limited(cc,ss,sections,'http_timeout')
504
505   return sections
506
507 def cfg_process_ipif(c, sections, varmap):
508   for d, s in varmap:
509     try: v = getattr(c, s)
510     except AttributeError: continue
511     setattr(c, d, v)
512
513   #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr)
514
515   section = cfg_search_section('ipif', sections)
516   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
517
518 #---------- startup ----------
519
520 def common_startup(process_cfg):
521   # calls process_cfg(putative_clients, putative_servers)
522
523   # ConfigParser hates #-comments after values
524   trailingcomments_re = regexp.compile(r'#.*')
525   cfg.read_string(trailingcomments_re.sub('', defcfg))
526   need_defcfg = True
527
528   def readconfig(pathname, mandatory=True):
529     def log(m, p=pathname):
530       if not DBG.CONFIG in debug_set: return
531       print('DBG.CONFIG: %s: %s' % (m, pathname))
532
533     try:
534       files = os.listdir(pathname)
535
536     except FileNotFoundError:
537       if mandatory: raise
538       log('skipped')
539       return
540
541     except NotADirectoryError:
542       cfg.read(pathname)
543       log('read file')
544       return
545
546     # is a directory
547     log('directory')
548     re = regexp.compile('[^-A-Za-z0-9_]')
549     for f in os.listdir(cdir):
550       if re.search(f): continue
551       subpath = pathname + '/' + f
552       try:
553         os.stat(subpath)
554       except FileNotFoundError:
555         log('entry skipped', subpath)
556         continue
557       cfg.read(subpath)
558       log('entry read', subpath)
559       
560   def oc_config(od,os, value, op):
561     nonlocal need_defcfg
562     need_defcfg = False
563     readconfig(value)
564
565   def read_defconfig():
566     readconfig('/etc/hippotat/config.d', False)
567     readconfig('/etc/hippotat/passwords.d', False)
568     readconfig('/etc/hippotat/master.cfg',   False)
569
570   def dfs_less_detailed(dl):
571     return [df for df in DBG.iterconstants() if df <= dl]
572
573   def ds_default(od,os,dl,op):
574     global debug_set
575     debug_set.clear
576     debug_set |= set(dfs_less_detailed(debug_def_detail))
577
578   def ds_select(od,os, spec, op):
579     for it in spec.split(','):
580
581       if it.startswith('-'):
582         mutator = debug_set.discard
583         it = it[1:]
584       else:
585         mutator = debug_set.add
586
587       if it == '+':
588         dfs = DBG.iterconstants()
589
590       else:
591         if it.endswith('+'):
592           mapper = dfs_less_detailed
593           it = it[0:len(it)-1]
594         else:
595           mapper = lambda x: [x]
596
597           try:
598             dfspec = DBG.lookupByName(it)
599           except ValueError:
600             optparser.error('unknown debug flag %s in --debug-select' % it)
601
602         dfs = mapper(dfspec)
603
604       for df in dfs:
605         mutator(df)
606
607   optparser.add_option('-D', '--debug',
608                        nargs=0,
609                        action='callback',
610                        help='enable default debug (to stdout)',
611                        callback= ds_default)
612
613   optparser.add_option('--debug-select',
614                        nargs=1,
615                        type='string',
616                        metavar='[-]DFLAG[+]|[-]+,...',
617                        help=
618 '''enable (`-': disable) each specified DFLAG;
619 `+': do same for all "more interesting" DFLAGSs;
620 just `+': all DFLAGs.
621   DFLAGS: ''' + ' '.join([df.name for df in DBG.iterconstants()]),
622                        action='callback',
623                        callback= ds_select)
624
625   optparser.add_option('-c', '--config',
626                        nargs=1,
627                        type='string',
628                        metavar='CONFIGFILE',
629                        dest='configfile',
630                        action='callback',
631                        callback= oc_config)
632
633   (opts, args) = optparser.parse_args()
634   if len(args): optparser.error('no non-option arguments please')
635
636   if need_defcfg:
637     read_defconfig()
638
639   try:
640     (pss, pcs) = _cfg_process_putatives()
641     process_cfg(opts, pss, pcs)
642   except (configparser.Error, ValueError):
643     traceback.print_exc(file=sys.stderr)
644     print('\nInvalid configuration, giving up.', file=sys.stderr)
645     sys.exit(12)
646
647
648   #print('X', debug_set, file=sys.stderr)
649
650   log_formatter = twisted.logger.formatEventAsClassicLogText
651   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
652   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
653   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
654   stdsomething_obs = twisted.logger.FilteringLogObserver(
655     stderr_obs, [pred], stdout_obs
656   )
657   global file_log_observer
658   file_log_observer = twisted.logger.FilteringLogObserver(
659     stdsomething_obs, [LogNotBoringTwisted()]
660   )
661   #log_observer = stdsomething_obs
662   twisted.logger.globalLogBeginner.beginLoggingTo(
663     [ file_log_observer, crash_on_critical ]
664     )
665
666 def common_run():
667   log_debug(DBG.INIT, 'entering reactor')
668   if not _crashing: reactor.run()
669   print('ENDED', file=sys.stderr)