X-Git-Url: http://www.chiark.greenend.org.uk/ucgi/~ian/git?p=hippotat.git;a=blobdiff_plain;f=hippotatlib%2F__init__.py;h=de939c86b09621dbddb7c5f1445ad1bf4ed4b717;hp=701a5f20092520f6df510fb9f31bf734d4ffe923;hb=a14782d3bb7fe3e65f19e45d913d2e5f5d8662bb;hpb=1cc6968f38db0ade45242e08f9aab1b1db3e43b1 diff --git a/hippotatlib/__init__.py b/hippotatlib/__init__.py index 701a5f2..de939c8 100644 --- a/hippotatlib/__init__.py +++ b/hippotatlib/__init__.py @@ -50,6 +50,9 @@ from functools import partial import collections import time +import hmac +import hashlib +import base64 import codecs import traceback @@ -100,27 +103,33 @@ def log_debug(dflag, msg, idof=None, d=None): msg += ' ' + d + trunc log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg) +def logevent_is_boringtwisted(event): + try: + if event.get('log_level') != LogLevel.info: + return False + dflag = event.get('dflag') + if dflag is False : return False + if dflag in debug_set: return False + if dflag is None and DBG.TWISTED in debug_set: return False + return True + except Exception: + print('EXCEPTION (IN BORINGTWISTED CHECK)', + traceback.format_exc(), file=org_stderr) + return False + @implementer(twisted.logger.ILogFilterPredicate) class LogNotBoringTwisted: def __call__(self, event): - yes = twisted.logger.PredicateResult.yes - no = twisted.logger.PredicateResult.no - try: - if event.get('log_level') != LogLevel.info: - return yes - dflag = event.get('dflag') - if dflag is False : return yes - if dflag in debug_set: return yes - if dflag is None and DBG.TWISTED in debug_set: return yes - return no - except Exception: - print(traceback.format_exc(), file=org_stderr) - return yes + return ( + twisted.logger.PredicateResult.no + if logevent_is_boringtwisted(event) else + twisted.logger.PredicateResult.yes + ) #---------- default config ---------- defcfg = ''' -[DEFAULT] +[COMMON] max_batch_down = 65536 max_queue_time = 10 target_requests_outstanding = 3 @@ -131,27 +140,28 @@ max_batch_up = 4000 http_retry = 5 port = 80 vroutes = '' +ifname_client = hippo%%d +ifname_server = shippo%%d +max_clock_skew = 300 #[server] or [] overrides -ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s +ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip,%(ifname)s %(rnets)s # relating to virtual network mtu = 1500 -[SERVER] -server = SERVER # addrs = 127.0.0.1 ::1 # url # relating to virtual network vvnetwork = 172.24.230.192 # vnetwork = / -# vadd r = +# vaddr = # vrelay = # [] -# password = # used by both, must match +# secret = # used by both, must match [LIMIT] max_batch_down = 262144 @@ -361,6 +371,34 @@ def crash_on_critical(event): if event.get('log_level') >= LogLevel.critical: crash(twisted.logger.formatEvent(event)) +#---------- authentication tokens ---------- + +_authtoken_digest = hashlib.sha256 + +def _authtoken_time(): + return int(time.time()) + +def _authtoken_hmac(secret, hextime): + return hmac.new(secret, hextime, _authtoken_digest).digest() + +def authtoken_make(secret): + hextime = ('%x' % _authtoken_time()).encode('ascii') + mac = _authtoken_hmac(secret, hextime) + return hextime + b' ' + base64.b64encode(mac) + +def authtoken_check(secret, token, maxskew): + (hextime, theirmac64) = token.split(b' ') + now = _authtoken_time() + then = int(hextime, 16) + skew = then - now; + if (abs(skew) > maxskew): + raise ValueError('too much clock skew (client %ds ahead)' % skew) + theirmac = base64.b64decode(theirmac64) + ourmac = _authtoken_hmac(secret, hextime) + if not hmac.compare_digest(theirmac, ourmac): + raise ValueError('invalid token (wrong secret?)') + pass + #---------- config processing ---------- def _cfg_process_putatives(): @@ -380,11 +418,21 @@ def _cfg_process_putatives(): server_pat = r'[-.0-9A-Za-z]+' client_pat = r'[.:0-9a-f]+' server_re = regexp.compile(server_pat) - serverclient_re = regexp.compile(server_pat + r' ' + client_pat) + serverclient_re = regexp.compile( + server_pat + r' ' + '(?:' + client_pat + '|LIMIT)') for cs in cfg.sections(): - if cs == 'LIMIT': - # plan A "[LIMIT]" + def dbg(m): + log_debug_config('putatives: section [%s] %s' % (cs, m)) + + def log_ignore(why): + dbg('X ignore: %s' % (why)) + print('warning: ignoring config section [%s] (%s)' % (cs, why), + file=sys.stderr) + + if cs == 'LIMIT' or cs == 'COMMON': + # plan A "[LIMIT]" or "[COMMON]" + dbg('A ignore') continue try: @@ -394,6 +442,7 @@ def _cfg_process_putatives(): if server_re.fullmatch(cs): # plan C "[]" + dbg('C ') putative(servers, cs, cs) continue @@ -403,32 +452,37 @@ def _cfg_process_putatives(): if pcs == 'LIMIT': # plan E "[ LIMIT]" + dbg('E LIMIT') continue try: # plan D "[ ]" part 2 - ci = ipaddr(pc) + ci = ipaddr(pcs) except AddressValueError: - # plan F "[]" - # well, we ignore this - print('warning: ignoring config section %s' % cs, file=sys.stderr) + # plan F branch 1 "[]" + log_ignore('bad-addr') continue else: # no AddressValueError - # plan D "[ ]" part 3 + dbg('D ') putative(clients, ci, pcs) putative(servers, pss, pss) continue + else: + # plan F branch 2 "[]" + log_ignore('nomatch '+ repr(serverclient_re)) else: # no AddressValueError # plan B "[" part 2 + dbg('B ') putative(clients, ci, cs) continue return (servers, clients) -def cfg_process_common(c, ss): - c.mtu = cfg.getint(ss, 'mtu') +def cfg_process_general(c, ss): + c.mtu = cfg1getint(ss, 'mtu') def cfg_process_saddrs(c, ss): class ServerAddr(): @@ -444,27 +498,30 @@ def cfg_process_saddrs(c, ss): self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint self._inurl = b'[%s]' def make_endpoint(self): - return self._endpointfactory(reactor, self.port, self.addr) + return self._endpointfactory(reactor, self.port, + interface= '%s' % self.addr) def url(self): url = b'http://' + (self._inurl % str(self.addr).encode('ascii')) if self.port != 80: url += b':%d' % self.port url += b'/' return url + def __repr__(self): + return 'ServerAddr'+repr((self.port,self.addr)) - c.port = cfg.getint(ss,'port') + c.port = cfg1getint(ss,'port') c.saddrs = [ ] - for addrspec in cfg.get(ss, 'addrs').split(): + for addrspec in cfg1get(ss, 'addrs').split(): sa = ServerAddr(c.port, addrspec) c.saddrs.append(sa) def cfg_process_vnetwork(c, ss): - c.vnetwork = ipnetwork(cfg.get(ss,'vnetwork')) + c.vnetwork = ipnetwork(cfg1get(ss,'vnetwork')) if c.vnetwork.num_addresses < 3 + 2: raise ValueError('vnetwork needs at least 2^3 addresses') def cfg_process_vaddr(c, ss): try: - c.vaddr = cfg.get(ss,'vaddr') + c.vaddr = cfg1get(ss,'vaddr') except NoOptionError: cfg_process_vnetwork(c, ss) c.vaddr = next(c.vnetwork.hosts()) @@ -475,29 +532,40 @@ def cfg_search_section(key,sections): return section raise NoOptionError(key, repr(sections)) +def cfg_get_raw(*args, **kwargs): + # for passing to cfg_search + return cfg.get(*args, raw=True, **kwargs) + def cfg_search(getter,key,sections): section = cfg_search_section(key,sections) return getter(section, key) +def cfg1get(section,key, getter=cfg.get,**kwargs): + section = cfg_search_section(key,[section,'COMMON']) + return getter(section,key,**kwargs) + +def cfg1getint(section,key, **kwargs): + return cfg1get(section,key, getter=cfg.getint,**kwargs); + def cfg_process_client_limited(cc,ss,sections,key): - val = cfg_search(cfg.getint, key, sections) - lim = cfg_search(cfg.getint, key, ['%s LIMIT' % ss, 'LIMIT']) + val = cfg_search(cfg1getint, key, sections) + lim = cfg_search(cfg1getint, key, ['%s LIMIT' % ss, 'LIMIT']) cc.__dict__[key] = min(val,lim) def cfg_process_client_common(cc,ss,cs,ci): - # returns sections to search in, iff password is defined, otherwise None + # returns sections to search in, iff secret is defined, otherwise None cc.ci = ci sections = ['%s %s' % (ss,cs), cs, ss, - 'DEFAULT'] + 'COMMON'] - try: pwsection = cfg_search_section('password', sections) + try: pwsection = cfg_search_section('secret', sections) except NoOptionError: return None - pw = cfg.get(pwsection, 'password') - cc.password = pw.encode('utf-8') + pw = cfg1get(pwsection, 'secret') + cc.secret = pw.encode('utf-8') cfg_process_client_limited(cc,ss,sections,'target_requests_outstanding') cfg_process_client_limited(cc,ss,sections,'http_timeout') @@ -509,14 +577,21 @@ def cfg_process_ipif(c, sections, varmap): try: v = getattr(c, s) except AttributeError: continue setattr(c, d, v) + for d in ('mtu',): + v = cfg_search(cfg.get, d, sections) + setattr(c, d, v) #print('CFGIPIF',repr((varmap, sections, c.__dict__)),file=sys.stderr) section = cfg_search_section('ipif', sections) - c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__) + c.ipif_command = cfg1get(section,'ipif', vars=c.__dict__) #---------- startup ---------- +def log_debug_config(m): + if not DBG.CONFIG in debug_set: return + print('DBG.CONFIG:', m) + def common_startup(process_cfg): # calls process_cfg(putative_clients, putative_servers) @@ -528,7 +603,7 @@ def common_startup(process_cfg): def readconfig(pathname, mandatory=True): def log(m, p=pathname): if not DBG.CONFIG in debug_set: return - print('DBG.CONFIG: %s: %s' % (m, pathname)) + log_debug_config('%s: %s' % (m, p)) try: files = os.listdir(pathname) @@ -546,7 +621,7 @@ def common_startup(process_cfg): # is a directory log('directory') re = regexp.compile('[^-A-Za-z0-9_]') - for f in os.listdir(cdir): + for f in os.listdir(pathname): if re.search(f): continue subpath = pathname + '/' + f try: @@ -562,6 +637,19 @@ def common_startup(process_cfg): need_defcfg = False readconfig(value) + def oc_extra_config(od,os, value, op): + readconfig(value) + + def read_defconfig(): + readconfig('/etc/hippotat/config.d', False) + readconfig('/etc/hippotat/secrets.d', False) + readconfig('/etc/hippotat/master.cfg', False) + + def oc_defconfig(od,os, value, op): + nonlocal need_defcfg + need_defcfg = False + read_defconfig(value) + def dfs_less_detailed(dl): return [df for df in DBG.iterconstants() if df <= dl] @@ -625,12 +713,23 @@ just `+': all DFLAGs. action='callback', callback= oc_config) + optparser.add_option('--extra-config', + nargs=1, + type='string', + metavar='CONFIGFILE', + dest='configfile', + action='callback', + callback= oc_extra_config) + + optparser.add_option('--default-config', + action='callback', + callback= oc_defconfig) + (opts, args) = optparser.parse_args() if len(args): optparser.error('no non-option arguments please') if need_defcfg: - readconfig('/etc/hippotat/config', False) - readconfig('/etc/hippotat/config.d', False) + read_defconfig() try: (pss, pcs) = _cfg_process_putatives() @@ -650,15 +749,17 @@ just `+': all DFLAGs. stdsomething_obs = twisted.logger.FilteringLogObserver( stderr_obs, [pred], stdout_obs ) - log_observer = twisted.logger.FilteringLogObserver( + global file_log_observer + file_log_observer = twisted.logger.FilteringLogObserver( stdsomething_obs, [LogNotBoringTwisted()] ) #log_observer = stdsomething_obs twisted.logger.globalLogBeginner.beginLoggingTo( - [ log_observer, crash_on_critical ] + [ file_log_observer, crash_on_critical ] ) def common_run(): log_debug(DBG.INIT, 'entering reactor') if not _crashing: reactor.run() - print('CRASHED (end)', file=sys.stderr) + print('ENDED', file=sys.stderr) + sys.exit(16)