chiark / gitweb /
fa43065759d82332d9f61c5b4d8305c5346bf32c
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
34 class DBG(twisted.python.constants.Names):
35   ROUTE = NamedConstant()
36   DROP = NamedConstant()
37   FLOW = NamedConstant()
38   HTTP = NamedConstant()
39   HTTP_CTRL = NamedConstant()
40   INIT = NamedConstant()
41   QUEUE = NamedConstant()
42   QUEUE_CTRL = NamedConstant()
43   HTTP_FULL = NamedConstant()
44   SLIP_FULL = NamedConstant()
45   CTRL_DUMP = NamedConstant()
46
47 _hex_codec = codecs.getencoder('hex_codec')
48
49 log = twisted.logger.Logger()
50
51 def log_debug(dflag, msg, idof=None, d=None):
52   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
53   if idof is not None:
54     msg = '[%#x] %s' % (id(idof), msg)
55   if d is not None:
56     #d = d[0:64]
57     d = _hex_codec(d)[0].decode('ascii')
58     msg += ' ' + d
59   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
60
61 defcfg = '''
62 [DEFAULT]
63 #[<client>] overrides
64 max_batch_down = 65536           # used by server, subject to [limits]
65 max_queue_time = 10              # used by server, subject to [limits]
66 target_requests_outstanding = 3  # must match; subject to [limits] on server
67 http_timeout = 30                # used by both } must be
68 http_timeout_grace = 5           # used by both }  compatible
69 max_requests_outstanding = 4     # used by client
70 max_batch_up = 4000              # used by client
71 http_retry = 5                   # used by client
72
73 #[server] or [<client>] overrides
74 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
75 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
76 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
77 #      from   on client  <client>         [virtual]server   [virtual]routes
78
79 [virtual]
80 mtu = 1500
81 routes = ''
82 # network = <prefix>/<len>  # mandatory for server
83 # server  = <ipaddr>   # used by both, default is computed from `network'
84 # relay   = <ipaddr>   # used by server, default from `network' and `server'
85 #  default server is first host in network
86 #  default relay is first host which is not server
87
88 [server]
89 # addrs = 127.0.0.1 ::1    # mandatory for server
90 port = 80                  # used by server
91 # url              # used by client; default from first `addrs' and `port'
92
93 # [<client-ip4-or-ipv6-address>]
94 # password = <password>    # used by both, must match
95
96 [limits]
97 max_batch_down = 262144           # used by server
98 max_queue_time = 121              # used by server
99 http_timeout = 121                # used by server
100 target_requests_outstanding = 10  # used by server
101 '''
102
103 # these need to be defined here so that they can be imported by import *
104 cfg = ConfigParser()
105 optparser = OptionParser()
106
107 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
108 def mime_translate(s):
109   # SLIP-encoded packets cannot contain ESC ESC.
110   # Swap `-' and ESC.  The result cannot contain `--'
111   return s.translate(_mimetrans)
112
113 class ConfigResults:
114   def __init__(self, d = { }):
115     self.__dict__ = d
116   def __repr__(self):
117     return 'ConfigResults('+repr(self.__dict__)+')'
118
119 c = ConfigResults()
120
121 def log_discard(packet, iface, saddr, daddr, why):
122   log_debug(DBG.DROP,
123             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
124             d=packet)
125
126 #---------- packet parsing ----------
127
128 def packet_addrs(packet):
129   version = packet[0] >> 4
130   if version == 4:
131     addrlen = 4
132     saddroff = 3*4
133     factory = ipaddress.IPv4Address
134   elif version == 6:
135     addrlen = 16
136     saddroff = 2*4
137     factory = ipaddress.IPv6Address
138   else:
139     raise ValueError('unsupported IP version %d' % version)
140   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
141   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
142   return (saddr, daddr)
143
144 #---------- address handling ----------
145
146 def ipaddr(input):
147   try:
148     r = ipaddress.IPv4Address(input)
149   except AddressValueError:
150     r = ipaddress.IPv6Address(input)
151   return r
152
153 def ipnetwork(input):
154   try:
155     r = ipaddress.IPv4Network(input)
156   except NetworkValueError:
157     r = ipaddress.IPv6Network(input)
158   return r
159
160 #---------- ipif (SLIP) subprocess ----------
161
162 class SlipStreamDecoder():
163   def __init__(self, desc, on_packet):
164     self._buffer = b''
165     self._on_packet = on_packet
166     self._desc = desc
167     self._log('__init__')
168
169   def _log(self, msg, **kwargs):
170     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
171
172   def inputdata(self, data):
173     self._log('inputdata', d=data)
174     packets = slip.decode(data)
175     packets[0] = self._buffer + packets[0]
176     self._buffer = packets.pop()
177     for packet in packets:
178       self._maybe_packet(packet)
179     self._log('bufremain', d=self._buffer)
180
181   def _maybe_packet(self, packet):
182     self._log('maybepacket', d=packet)
183     if len(packet):
184       self._on_packet(packet)
185
186   def flush(self):
187     self._log('flush')
188     self._maybe_packet(self._buffer)
189     self._buffer = b''
190
191 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
192   def __init__(self, router):
193     self._router = router
194     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
195   def connectionMade(self): pass
196   def outReceived(self, data):
197     self._decoder.inputdata(data)
198   def slip_on_packet(self, packet):
199     (saddr, daddr) = packet_addrs(packet)
200     if saddr.is_link_local or daddr.is_link_local:
201       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
202       return
203     self._router(packet, saddr, daddr)
204   def processEnded(self, status):
205     status.raiseException()
206
207 def start_ipif(command, router):
208   global ipif
209   ipif = _IpifProcessProtocol(router)
210   reactor.spawnProcess(ipif,
211                        '/bin/sh',['sh','-xc', command],
212                        childFDs={0:'w', 1:'r', 2:2},
213                        env=None)
214
215 def queue_inbound(packet):
216   log_debug(DBG.FLOW, "queue_inbound", d=packet)
217   ipif.transport.write(slip.delimiter)
218   ipif.transport.write(slip.encode(packet))
219   ipif.transport.write(slip.delimiter)
220
221 #---------- packet queue ----------
222
223 class PacketQueue():
224   def __init__(self, desc, max_queue_time):
225     self._desc = desc
226     assert(desc + '')
227     self._max_queue_time = max_queue_time
228     self._pq = collections.deque() # packets
229
230   def _log(self, dflag, msg, **kwargs):
231     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
232
233   def append(self, packet):
234     self._log(DBG.QUEUE, 'append', d=packet)
235     self._pq.append((time.monotonic(), packet))
236
237   def nonempty(self):
238     self._log(DBG.QUEUE, 'nonempty ?')
239     while True:
240       try: (queuetime, packet) = self._pq[0]
241       except IndexError:
242         self._log(DBG.QUEUE, 'nonempty ? empty.')
243         return False
244
245       age = time.monotonic() - queuetime
246       if age > self._max_queue_time:
247         # strip old packets off the front
248         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
249         self._pq.popleft()
250         continue
251
252       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
253       return True
254
255   def process(self, sizequery, moredata, max_batch):
256     # sizequery() should return size of batch so far
257     # moredata(s) should add s to batch
258     self._log(DBG.QUEUE, 'process...')
259     while True:
260       try: (dummy, packet) = self._pq[0]
261       except IndexError:
262         self._log(DBG.QUEUE, 'process... empty')
263         break
264
265       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
266
267       encoded = slip.encode(packet)
268       sofar = sizequery()  
269
270       self._log(DBG.QUEUE_CTRL,
271                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
272                 d=encoded)
273
274       if sofar > 0:
275         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
276           self._log(DBG.QUEUE_CTRL, 'process... overflow')
277           break
278         moredata(slip.delimiter)
279
280       moredata(encoded)
281       self._pq.popleft()
282
283 #---------- error handling ----------
284
285 _crashing = False
286
287 def crash(err):
288   global _crashing
289   _crashing = True
290   print('========== CRASH ==========', err,
291         '===========================', file=sys.stderr)
292   try: reactor.stop()
293   except twisted.internet.error.ReactorNotRunning: pass
294
295 def crash_on_defer(defer):
296   defer.addErrback(lambda err: crash(err))
297
298 def crash_on_critical(event):
299   if event.get('log_level') >= LogLevel.critical:
300     crash(twisted.logger.formatEvent(event))
301
302 #---------- config processing ----------
303
304 def process_cfg_common_always():
305   global mtu
306   c.mtu = cfg.get('virtual','mtu')
307
308 def process_cfg_ipif(section, varmap):
309   for d, s in varmap:
310     try: v = getattr(c, s)
311     except AttributeError: continue
312     setattr(c, d, v)
313
314   #print(repr(c))
315
316   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
317
318 def process_cfg_network():
319   c.network = ipnetwork(cfg.get('virtual','network'))
320   if c.network.num_addresses < 3 + 2:
321     raise ValueError('network needs at least 2^3 addresses')
322
323 def process_cfg_server():
324   try:
325     c.server = cfg.get('virtual','server')
326   except NoOptionError:
327     process_cfg_network()
328     c.server = next(c.network.hosts())
329
330 class ServerAddr():
331   def __init__(self, port, addrspec):
332     self.port = port
333     # also self.addr
334     try:
335       self.addr = ipaddress.IPv4Address(addrspec)
336       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
337       self._inurl = b'%s'
338     except AddressValueError:
339       self.addr = ipaddress.IPv6Address(addrspec)
340       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
341       self._inurl = b'[%s]'
342   def make_endpoint(self):
343     return self._endpointfactory(reactor, self.port, self.addr)
344   def url(self):
345     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
346     if self.port != 80: url += b':%d' % self.port
347     url += b'/'
348     return url
349     
350 def process_cfg_saddrs():
351   try: port = cfg.getint('server','port')
352   except NoOptionError: port = 80
353
354   c.saddrs = [ ]
355   for addrspec in cfg.get('server','addrs').split():
356     sa = ServerAddr(port, addrspec)
357     c.saddrs.append(sa)
358
359 def process_cfg_clients(constructor):
360   c.clients = [ ]
361   for cs in cfg.sections():
362     if not (':' in cs or '.' in cs): continue
363     ci = ipaddr(cs)
364     pw = cfg.get(cs, 'password')
365     pw = pw.encode('utf-8')
366     constructor(ci,cs,pw)
367
368 #---------- startup ----------
369
370 def common_startup():
371   log_formatter = twisted.logger.formatEventAsClassicLogText
372   stdout_obs = twisted.logger.FileLogObserver(sys.stdout, log_formatter)
373   stderr_obs = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
374   pred = twisted.logger.LogLevelFilterPredicate(LogLevel.error)
375   log_observer = twisted.logger.FilteringLogObserver(
376     stderr_obs, [pred], stdout_obs
377   )
378   twisted.logger.globalLogBeginner.beginLoggingTo(
379     [ log_observer, crash_on_critical ]
380     )
381
382   optparser.add_option('-c', '--config', dest='configfile',
383                        default='/etc/hippotat/config')
384   (opts, args) = optparser.parse_args()
385   if len(args): optparser.error('no non-option arguments please')
386
387   re = regexp.compile('#.*')
388   cfg.read_string(re.sub('', defcfg))
389   cfg.read(opts.configfile)
390
391 def common_run():
392   log_debug(DBG.INIT, 'entering reactor')
393   if not _crashing: reactor.run()
394   print('CRASHED (end)', file=sys.stderr)