chiark / gitweb /
before undo GeneralResponseConsumer
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
32 class DBG(twisted.python.constants.Names):
33   ROUTE = NamedConstant()
34   DROP = NamedConstant()
35   FLOW = NamedConstant()
36   HTTP = NamedConstant()
37   HTTP_CTRL = NamedConstant()
38   INIT = NamedConstant()
39   QUEUE = NamedConstant()
40   QUEUE_CTRL = NamedConstant()
41   HTTP_FULL = NamedConstant()
42   SLIP_FULL = NamedConstant()
43   CTRL_DUMP = NamedConstant()
44
45 _hex_codec = codecs.getencoder('hex_codec')
46
47 log = twisted.logger.Logger()
48
49 def log_debug(dflag, msg, idof=None, d=None):
50   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
51   if idof is not None:
52     msg = '[%d] %s' % (id(idof), msg)
53   if d is not None:
54     d = d[0:64]
55     d = _hex_codec(d)[0].decode('ascii')
56     msg += ' ' + d
57   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
58
59 defcfg = '''
60 [DEFAULT]
61 #[<client>] overrides
62 max_batch_down = 65536           # used by server, subject to [limits]
63 max_queue_time = 10              # used by server, subject to [limits]
64 max_request_time = 54            # used by server, subject to [limits]
65 target_requests_outstanding = 3  # must match; subject to [limits] on server
66 max_requests_outstanding = 4     # used by client
67 max_batch_up = 4000              # used by client
68 http_timeout = 30                # used by client
69 http_retry = 5                   # used by client
70
71 #[server] or [<client>] overrides
72 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
73 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
74 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
75 #      from   on client  <client>         [virtual]server   [virtual]routes
76
77 [virtual]
78 mtu = 1500
79 routes = ''
80 # network = <prefix>/<len>  # mandatory for server
81 # server  = <ipaddr>   # used by both, default is computed from `network'
82 # relay   = <ipaddr>   # used by server, default from `network' and `server'
83 #  default server is first host in network
84 #  default relay is first host which is not server
85
86 [server]
87 # addrs = 127.0.0.1 ::1    # mandatory for server
88 port = 80                  # used by server
89 # url              # used by client; default from first `addrs' and `port'
90
91 # [<client-ip4-or-ipv6-address>]
92 # password = <password>    # used by both, must match
93
94 [limits]
95 max_batch_down = 262144           # used by server
96 max_queue_time = 121              # used by server
97 max_request_time = 121            # used by server
98 target_requests_outstanding = 10  # used by server
99 '''
100
101 # these need to be defined here so that they can be imported by import *
102 cfg = ConfigParser()
103 optparser = OptionParser()
104
105 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
106 def mime_translate(s):
107   # SLIP-encoded packets cannot contain ESC ESC.
108   # Swap `-' and ESC.  The result cannot contain `--'
109   return s.translate(_mimetrans)
110
111 class ConfigResults:
112   def __init__(self, d = { }):
113     self.__dict__ = d
114   def __repr__(self):
115     return 'ConfigResults('+repr(self.__dict__)+')'
116
117 c = ConfigResults()
118
119 def log_discard(packet, iface, saddr, daddr, why):
120   log_debug(DBG.DROP,
121             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
122             d=packet)
123
124 #---------- packet parsing ----------
125
126 def packet_addrs(packet):
127   version = packet[0] >> 4
128   if version == 4:
129     addrlen = 4
130     saddroff = 3*4
131     factory = ipaddress.IPv4Address
132   elif version == 6:
133     addrlen = 16
134     saddroff = 2*4
135     factory = ipaddress.IPv6Address
136   else:
137     raise ValueError('unsupported IP version %d' % version)
138   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
139   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
140   return (saddr, daddr)
141
142 #---------- address handling ----------
143
144 def ipaddr(input):
145   try:
146     r = ipaddress.IPv4Address(input)
147   except AddressValueError:
148     r = ipaddress.IPv6Address(input)
149   return r
150
151 def ipnetwork(input):
152   try:
153     r = ipaddress.IPv4Network(input)
154   except NetworkValueError:
155     r = ipaddress.IPv6Network(input)
156   return r
157
158 #---------- ipif (SLIP) subprocess ----------
159
160 class SlipStreamDecoder():
161   def __init__(self, desc, on_packet):
162     self._buffer = b''
163     self._on_packet = on_packet
164     self._desc = desc
165     self._log('__init__')
166
167   def _log(self, msg, **kwargs):
168     log_debug(DBG.SLIP_FULL, 'slip %s: %s' % (self._desc, msg), **kwargs)
169
170   def inputdata(self, data):
171     self._log('inputdata', d=data)
172     data = self._buffer + data
173     self._buffer = b''
174     packets = slip.decode(data)
175     self._buffer = packets.pop()
176     for packet in packets:
177       self._maybe_packet(packet)
178     self._log('bufremain', d=self._buffer)
179
180   def _maybe_packet(self, packet):
181     self._log('maybepacket', d=packet)
182     if len(packet):
183       self._on_packet(packet)
184
185   def flush(self):
186     self._log('flush')
187     self._maybe_packet(self._buffer)
188     self._buffer = b''
189
190 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
191   def __init__(self, router):
192     self._router = router
193     self._decoder = SlipStreamDecoder('ipif', self.slip_on_packet)
194   def connectionMade(self): pass
195   def outReceived(self, data):
196     self._decoder.inputdata(data)
197   def slip_on_packet(self, packet):
198     (saddr, daddr) = packet_addrs(packet)
199     if saddr.is_link_local or daddr.is_link_local:
200       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
201       return
202     self._router(packet, saddr, daddr)
203   def processEnded(self, status):
204     status.raiseException()
205
206 def start_ipif(command, router):
207   global ipif
208   ipif = _IpifProcessProtocol(router)
209   reactor.spawnProcess(ipif,
210                        '/bin/sh',['sh','-xc', command],
211                        childFDs={0:'w', 1:'r', 2:2},
212                        env=None)
213
214 def queue_inbound(packet):
215   log_debug(DBG.FLOW, "queue_inbound", d=packet)
216   ipif.transport.write(slip.delimiter)
217   ipif.transport.write(slip.encode(packet))
218   ipif.transport.write(slip.delimiter)
219
220 #---------- packet queue ----------
221
222 class PacketQueue():
223   def __init__(self, desc, max_queue_time):
224     self._desc = desc
225     assert(desc + '')
226     self._max_queue_time = max_queue_time
227     self._pq = collections.deque() # packets
228
229   def _log(self, dflag, msg, **kwargs):
230     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
231
232   def append(self, packet):
233     self._log(DBG.QUEUE, 'append', d=packet)
234     self._pq.append((time.monotonic(), packet))
235
236   def nonempty(self):
237     self._log(DBG.QUEUE, 'nonempty ?')
238     while True:
239       try: (queuetime, packet) = self._pq[0]
240       except IndexError:
241         self._log(DBG.QUEUE, 'nonempty ? empty.')
242         return False
243
244       age = time.monotonic() - queuetime
245       if age > self._max_queue_time:
246         # strip old packets off the front
247         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
248         self._pq.popleft()
249         continue
250
251       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
252       return True
253
254   def process(self, sizequery, moredata, max_batch):
255     # sizequery() should return size of batch so far
256     # moredata(s) should add s to batch
257     self._log(DBG.QUEUE, 'process...')
258     while True:
259       try: (dummy, packet) = self._pq[0]
260       except IndexError:
261         self._log(DBG.QUEUE, 'process... empty')
262         break
263
264       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
265
266       encoded = slip.encode(packet)
267       sofar = sizequery()  
268
269       self._log(DBG.QUEUE_CTRL,
270                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
271                 d=encoded)
272
273       if sofar > 0:
274         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
275           self._log(DBG.QUEUE_CTRL, 'process... overflow')
276           break
277         moredata(slip.delimiter)
278
279       moredata(encoded)
280       self._pq.popleft()
281
282 #---------- error handling ----------
283
284 _crashing = False
285
286 def crash(err):
287   global _crashing
288   _crashing = True
289   print('CRASH ', err, file=sys.stderr)
290   try: reactor.stop()
291   except twisted.internet.error.ReactorNotRunning: pass
292
293 def crash_on_defer(defer):
294   defer.addErrback(lambda err: crash(err))
295
296 def crash_on_critical(event):
297   if event.get('log_level') >= LogLevel.critical:
298     crash(twisted.logger.formatEvent(event))
299
300 #---------- config processing ----------
301
302 def process_cfg_common_always():
303   global mtu
304   c.mtu = cfg.get('virtual','mtu')
305
306 def process_cfg_ipif(section, varmap):
307   for d, s in varmap:
308     try: v = getattr(c, s)
309     except AttributeError: continue
310     setattr(c, d, v)
311
312   #print(repr(c))
313
314   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
315
316 def process_cfg_network():
317   c.network = ipnetwork(cfg.get('virtual','network'))
318   if c.network.num_addresses < 3 + 2:
319     raise ValueError('network needs at least 2^3 addresses')
320
321 def process_cfg_server():
322   try:
323     c.server = cfg.get('virtual','server')
324   except NoOptionError:
325     process_cfg_network()
326     c.server = next(c.network.hosts())
327
328 class ServerAddr():
329   def __init__(self, port, addrspec):
330     self.port = port
331     # also self.addr
332     try:
333       self.addr = ipaddress.IPv4Address(addrspec)
334       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
335       self._inurl = b'%s'
336     except AddressValueError:
337       self.addr = ipaddress.IPv6Address(addrspec)
338       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
339       self._inurl = b'[%s]'
340   def make_endpoint(self):
341     return self._endpointfactory(reactor, self.port, self.addr)
342   def url(self):
343     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
344     if self.port != 80: url += b':%d' % self.port
345     url += b'/'
346     return url
347     
348 def process_cfg_saddrs():
349   try: port = cfg.getint('server','port')
350   except NoOptionError: port = 80
351
352   c.saddrs = [ ]
353   for addrspec in cfg.get('server','addrs').split():
354     sa = ServerAddr(port, addrspec)
355     c.saddrs.append(sa)
356
357 def process_cfg_clients(constructor):
358   c.clients = [ ]
359   for cs in cfg.sections():
360     if not (':' in cs or '.' in cs): continue
361     ci = ipaddr(cs)
362     pw = cfg.get(cs, 'password')
363     pw = pw.encode('utf-8')
364     constructor(ci,cs,pw)
365
366 #---------- startup ----------
367
368 def common_startup():
369   log_formatter = twisted.logger.formatEventAsClassicLogText
370   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
371   twisted.logger.globalLogBeginner.beginLoggingTo(
372     [ log_observer, crash_on_critical ]
373     )
374
375   optparser.add_option('-c', '--config', dest='configfile',
376                        default='/etc/hippotat/config')
377   (opts, args) = optparser.parse_args()
378   if len(args): optparser.error('no non-option arguments please')
379
380   re = regexp.compile('#.*')
381   cfg.read_string(re.sub('', defcfg))
382   cfg.read(opts.configfile)
383
384 def common_run():
385   log_debug(DBG.INIT, 'entering reactor')
386   if not _crashing: reactor.run()
387   print('CRASHED (end)', file=sys.stderr)