chiark / gitweb /
it pings!
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 from functools import partial
24
25 import collections
26 import time
27 import codecs
28 import traceback
29
30 import re as regexp
31
32 import hippotat.slip as slip
33
34 class DBG(twisted.python.constants.Names):
35   ROUTE = NamedConstant()
36   DROP = NamedConstant()
37   FLOW = NamedConstant()
38   HTTP = NamedConstant()
39   HTTP_CTRL = NamedConstant()
40   INIT = NamedConstant()
41   QUEUE = NamedConstant()
42   QUEUE_CTRL = NamedConstant()
43   HTTP_FULL = NamedConstant()
44   SLIP_FULL = NamedConstant()
45   CTRL_DUMP = NamedConstant()
46
47 _hex_codec = codecs.getencoder('hex_codec')
48
49 log = twisted.logger.Logger()
50
51 def log_debug(dflag, msg, idof=None, d=None):
52   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
53   if idof is not None:
54     msg = '[%#x] %s' % (id(idof), msg)
55   if d is not None:
56     d = d[0:64]
57     d = _hex_codec(d)[0].decode('ascii')
58     msg += ' ' + d
59   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
60
61 defcfg = '''
62 [DEFAULT]
63 #[<client>] overrides
64 max_batch_down = 65536           # used by server, subject to [limits]
65 max_queue_time = 10              # used by server, subject to [limits]
66 max_request_time = 54            # used by server, subject to [limits]
67 target_requests_outstanding = 3  # must match; subject to [limits] on server
68 max_requests_outstanding = 4     # used by client
69 max_batch_up = 4000              # used by client
70 http_timeout = 30                # used by client
71 http_retry = 5                   # used by client
72
73 #[server] or [<client>] overrides
74 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
75 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
76 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
77 #      from   on client  <client>         [virtual]server   [virtual]routes
78
79 [virtual]
80 mtu = 1500
81 routes = ''
82 # network = <prefix>/<len>  # mandatory for server
83 # server  = <ipaddr>   # used by both, default is computed from `network'
84 # relay   = <ipaddr>   # used by server, default from `network' and `server'
85 #  default server is first host in network
86 #  default relay is first host which is not server
87
88 [server]
89 # addrs = 127.0.0.1 ::1    # mandatory for server
90 port = 80                  # used by server
91 # url              # used by client; default from first `addrs' and `port'
92
93 # [<client-ip4-or-ipv6-address>]
94 # password = <password>    # used by both, must match
95
96 [limits]
97 max_batch_down = 262144           # used by server
98 max_queue_time = 121              # used by server
99 max_request_time = 121            # used by server
100 target_requests_outstanding = 10  # used by server
101 '''
102
103 # these need to be defined here so that they can be imported by import *
104 cfg = ConfigParser()
105 optparser = OptionParser()
106
107 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
108 def mime_translate(s):
109   # SLIP-encoded packets cannot contain ESC ESC.
110   # Swap `-' and ESC.  The result cannot contain `--'
111   return s.translate(_mimetrans)
112
113 class ConfigResults:
114   def __init__(self, d = { }):
115     self.__dict__ = d
116   def __repr__(self):
117     return 'ConfigResults('+repr(self.__dict__)+')'
118
119 c = ConfigResults()
120
121 def log_discard(packet, iface, saddr, daddr, why):
122   log_debug(DBG.DROP,
123             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
124             d=packet)
125
126 #---------- packet parsing ----------
127
128 def packet_addrs(packet):
129   version = packet[0] >> 4
130   if version == 4:
131     addrlen = 4
132     saddroff = 3*4
133     factory = ipaddress.IPv4Address
134   elif version == 6:
135     addrlen = 16
136     saddroff = 2*4
137     factory = ipaddress.IPv6Address
138   else:
139     raise ValueError('unsupported IP version %d' % version)
140   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
141   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
142   return (saddr, daddr)
143
144 #---------- address handling ----------
145
146 def ipaddr(input):
147   try:
148     r = ipaddress.IPv4Address(input)
149   except AddressValueError:
150     r = ipaddress.IPv6Address(input)
151   return r
152
153 def ipnetwork(input):
154   try:
155     r = ipaddress.IPv4Network(input)
156   except NetworkValueError:
157     r = ipaddress.IPv6Network(input)
158   return r
159
160 #---------- ipif (SLIP) subprocess ----------
161
162 class SlipStreamDecoder():
163   def __init__(self, desc, on_packet):
164     self._buffer = b''
165     self._on_packet = on_packet
166     self._desc = desc
167     self._log('__init__')
168
169   def _log(self, msg, **kwargs):
170     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
171
172   def inputdata(self, data):
173     self._log('inputdata', d=data)
174     data = self._buffer + data
175     self._buffer = b''
176     packets = slip.decode(data)
177     self._buffer = packets.pop()
178     for packet in packets:
179       self._maybe_packet(packet)
180     self._log('bufremain', d=self._buffer)
181
182   def _maybe_packet(self, packet):
183     self._log('maybepacket', d=packet)
184     if len(packet):
185       self._on_packet(packet)
186
187   def flush(self):
188     self._log('flush')
189     self._maybe_packet(self._buffer)
190     self._buffer = b''
191
192 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
193   def __init__(self, router):
194     self._router = router
195     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
196   def connectionMade(self): pass
197   def outReceived(self, data):
198     self._decoder.inputdata(data)
199   def slip_on_packet(self, packet):
200     (saddr, daddr) = packet_addrs(packet)
201     if saddr.is_link_local or daddr.is_link_local:
202       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
203       return
204     self._router(packet, saddr, daddr)
205   def processEnded(self, status):
206     status.raiseException()
207
208 def start_ipif(command, router):
209   global ipif
210   ipif = _IpifProcessProtocol(router)
211   reactor.spawnProcess(ipif,
212                        '/bin/sh',['sh','-xc', command],
213                        childFDs={0:'w', 1:'r', 2:2},
214                        env=None)
215
216 def queue_inbound(packet):
217   log_debug(DBG.FLOW, "queue_inbound", d=packet)
218   ipif.transport.write(slip.delimiter)
219   ipif.transport.write(slip.encode(packet))
220   ipif.transport.write(slip.delimiter)
221
222 #---------- packet queue ----------
223
224 class PacketQueue():
225   def __init__(self, desc, max_queue_time):
226     self._desc = desc
227     assert(desc + '')
228     self._max_queue_time = max_queue_time
229     self._pq = collections.deque() # packets
230
231   def _log(self, dflag, msg, **kwargs):
232     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
233
234   def append(self, packet):
235     self._log(DBG.QUEUE, 'append', d=packet)
236     self._pq.append((time.monotonic(), packet))
237
238   def nonempty(self):
239     self._log(DBG.QUEUE, 'nonempty ?')
240     while True:
241       try: (queuetime, packet) = self._pq[0]
242       except IndexError:
243         self._log(DBG.QUEUE, 'nonempty ? empty.')
244         return False
245
246       age = time.monotonic() - queuetime
247       if age > self._max_queue_time:
248         # strip old packets off the front
249         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
250         self._pq.popleft()
251         continue
252
253       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
254       return True
255
256   def process(self, sizequery, moredata, max_batch):
257     # sizequery() should return size of batch so far
258     # moredata(s) should add s to batch
259     self._log(DBG.QUEUE, 'process...')
260     while True:
261       try: (dummy, packet) = self._pq[0]
262       except IndexError:
263         self._log(DBG.QUEUE, 'process... empty')
264         break
265
266       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
267
268       encoded = slip.encode(packet)
269       sofar = sizequery()  
270
271       self._log(DBG.QUEUE_CTRL,
272                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
273                 d=encoded)
274
275       if sofar > 0:
276         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
277           self._log(DBG.QUEUE_CTRL, 'process... overflow')
278           break
279         moredata(slip.delimiter)
280
281       moredata(encoded)
282       self._pq.popleft()
283
284 #---------- error handling ----------
285
286 _crashing = False
287
288 def crash(err):
289   global _crashing
290   _crashing = True
291   print('========== CRASH ==========', err,
292         '===========================', file=sys.stderr)
293   try: reactor.stop()
294   except twisted.internet.error.ReactorNotRunning: pass
295
296 def crash_on_defer(defer):
297   defer.addErrback(lambda err: crash(err))
298
299 def crash_on_critical(event):
300   if event.get('log_level') >= LogLevel.critical:
301     crash(twisted.logger.formatEvent(event))
302
303 #---------- config processing ----------
304
305 def process_cfg_common_always():
306   global mtu
307   c.mtu = cfg.get('virtual','mtu')
308
309 def process_cfg_ipif(section, varmap):
310   for d, s in varmap:
311     try: v = getattr(c, s)
312     except AttributeError: continue
313     setattr(c, d, v)
314
315   #print(repr(c))
316
317   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
318
319 def process_cfg_network():
320   c.network = ipnetwork(cfg.get('virtual','network'))
321   if c.network.num_addresses < 3 + 2:
322     raise ValueError('network needs at least 2^3 addresses')
323
324 def process_cfg_server():
325   try:
326     c.server = cfg.get('virtual','server')
327   except NoOptionError:
328     process_cfg_network()
329     c.server = next(c.network.hosts())
330
331 class ServerAddr():
332   def __init__(self, port, addrspec):
333     self.port = port
334     # also self.addr
335     try:
336       self.addr = ipaddress.IPv4Address(addrspec)
337       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
338       self._inurl = b'%s'
339     except AddressValueError:
340       self.addr = ipaddress.IPv6Address(addrspec)
341       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
342       self._inurl = b'[%s]'
343   def make_endpoint(self):
344     return self._endpointfactory(reactor, self.port, self.addr)
345   def url(self):
346     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
347     if self.port != 80: url += b':%d' % self.port
348     url += b'/'
349     return url
350     
351 def process_cfg_saddrs():
352   try: port = cfg.getint('server','port')
353   except NoOptionError: port = 80
354
355   c.saddrs = [ ]
356   for addrspec in cfg.get('server','addrs').split():
357     sa = ServerAddr(port, addrspec)
358     c.saddrs.append(sa)
359
360 def process_cfg_clients(constructor):
361   c.clients = [ ]
362   for cs in cfg.sections():
363     if not (':' in cs or '.' in cs): continue
364     ci = ipaddr(cs)
365     pw = cfg.get(cs, 'password')
366     pw = pw.encode('utf-8')
367     constructor(ci,cs,pw)
368
369 #---------- startup ----------
370
371 def common_startup():
372   log_formatter = twisted.logger.formatEventAsClassicLogText
373   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
374   twisted.logger.globalLogBeginner.beginLoggingTo(
375     [ log_observer, crash_on_critical ]
376     )
377
378   optparser.add_option('-c', '--config', dest='configfile',
379                        default='/etc/hippotat/config')
380   (opts, args) = optparser.parse_args()
381   if len(args): optparser.error('no non-option arguments please')
382
383   re = regexp.compile('#.*')
384   cfg.read_string(re.sub('', defcfg))
385   cfg.read(opts.configfile)
386
387 def common_run():
388   log_debug(DBG.INIT, 'entering reactor')
389   if not _crashing: reactor.run()
390   print('CRASHED (end)', file=sys.stderr)