chiark / gitweb /
22c92dcf0f4b05d67a8a5e1440d57a0976c732ec
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   DROP = NamedConstant()
34   FLOW = NamedConstant()
35   HTTP = NamedConstant()
36   HTTP_CTRL = NamedConstant()
37   INIT = NamedConstant()
38   QUEUE = NamedConstant()
39   QUEUE_CTRL = NamedConstant()
40   HTTP_FULL = NamedConstant()
41
42 _hex_codec = codecs.getencoder('hex_codec')
43
44 log = twisted.logger.Logger()
45
46 def log_debug(dflag, msg, idof=None, d=None):
47   if idof is not None:
48     msg = '[%d] %s' % (id(idof), msg)
49   if d is not None:
50     d = d[0:64]
51     d = _hex_codec(d)[0].decode('ascii')
52     msg += ' ' + d
53   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
54
55 defcfg = '''
56 [DEFAULT]
57 #[<client>] overrides
58 max_batch_down = 65536           # used by server, subject to [limits]
59 max_queue_time = 10              # used by server, subject to [limits]
60 max_request_time = 54            # used by server, subject to [limits]
61 target_requests_outstanding = 3  # must match; subject to [limits] on server
62 max_requests_outstanding = 4     # used by client
63 max_batch_up = 4000              # used by client
64 http_timeout = 30                # used by client
65 http_retry = 5                   # used by client
66
67 #[server] or [<client>] overrides
68 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
69 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
70 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
71 #      from   on client  <client>         [virtual]server   [virtual]routes
72
73 [virtual]
74 mtu = 1500
75 routes = ''
76 # network = <prefix>/<len>  # mandatory for server
77 # server  = <ipaddr>   # used by both, default is computed from `network'
78 # relay   = <ipaddr>   # used by server, default from `network' and `server'
79 #  default server is first host in network
80 #  default relay is first host which is not server
81
82 [server]
83 # addrs = 127.0.0.1 ::1    # mandatory for server
84 port = 80                  # used by server
85 # url              # used by client; default from first `addrs' and `port'
86
87 # [<client-ip4-or-ipv6-address>]
88 # password = <password>    # used by both, must match
89
90 [limits]
91 max_batch_down = 262144           # used by server
92 max_queue_time = 121              # used by server
93 max_request_time = 121            # used by server
94 target_requests_outstanding = 10  # used by server
95 '''
96
97 # these need to be defined here so that they can be imported by import *
98 cfg = ConfigParser()
99 optparser = OptionParser()
100
101 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
102 def mime_translate(s):
103   # SLIP-encoded packets cannot contain ESC ESC.
104   # Swap `-' and ESC.  The result cannot contain `--'
105   return s.translate(_mimetrans)
106
107 class ConfigResults:
108   def __init__(self, d = { }):
109     self.__dict__ = d
110   def __repr__(self):
111     return 'ConfigResults('+repr(self.__dict__)+')'
112
113 c = ConfigResults()
114
115 def log_discard(packet, iface, saddr, daddr, why):
116   log_debug(DBG.DROP,
117             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
118             d=packet)
119
120 #---------- packet parsing ----------
121
122 def packet_addrs(packet):
123   version = packet[0] >> 4
124   if version == 4:
125     addrlen = 4
126     saddroff = 3*4
127     factory = ipaddress.IPv4Address
128   elif version == 6:
129     addrlen = 16
130     saddroff = 2*4
131     factory = ipaddress.IPv6Address
132   else:
133     raise ValueError('unsupported IP version %d' % version)
134   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
135   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
136   return (saddr, daddr)
137
138 #---------- address handling ----------
139
140 def ipaddr(input):
141   try:
142     r = ipaddress.IPv4Address(input)
143   except AddressValueError:
144     r = ipaddress.IPv6Address(input)
145   return r
146
147 def ipnetwork(input):
148   try:
149     r = ipaddress.IPv4Network(input)
150   except NetworkValueError:
151     r = ipaddress.IPv6Network(input)
152   return r
153
154 #---------- ipif (SLIP) subprocess ----------
155
156 class SlipStreamDecoder():
157   def __init__(self, on_packet):
158     # we will call packet(<packet>)
159     self._buffer = b''
160     self._on_packet = on_packet
161
162   def inputdata(self, data):
163     #print('SLIP-GOT ', repr(data))
164     self._buffer += data
165     packets = slip.decode(self._buffer)
166     self._buffer = packets.pop()
167     for packet in packets:
168       self._maybe_packet(packet)
169
170   def _maybe_packet(self, packet):
171       if len(packet):
172         self._on_packet(packet)
173
174   def flush(self):
175     self._maybe_packet(self._buffer)
176     self._buffer = b''
177
178 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
179   def __init__(self, router):
180     self._router = router
181     self._decoder = SlipStreamDecoder(self.slip_on_packet)
182   def connectionMade(self): pass
183   def outReceived(self, data):
184     self._decoder.inputdata(data)
185   def slip_on_packet(self, packet):
186     (saddr, daddr) = packet_addrs(packet)
187     if saddr.is_link_local or daddr.is_link_local:
188       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
189       return
190     self._router(packet, saddr, daddr)
191   def processEnded(self, status):
192     status.raiseException()
193
194 def start_ipif(command, router):
195   global ipif
196   ipif = _IpifProcessProtocol(router)
197   reactor.spawnProcess(ipif,
198                        '/bin/sh',['sh','-xc', command],
199                        childFDs={0:'w', 1:'r', 2:2},
200                        env=None)
201
202 def queue_inbound(packet):
203   ipif.transport.write(slip.delimiter)
204   ipif.transport.write(slip.encode(packet))
205   ipif.transport.write(slip.delimiter)
206
207 #---------- packet queue ----------
208
209 class PacketQueue():
210   def __init__(self, desc, max_queue_time):
211     self._desc = desc
212     self._max_queue_time = max_queue_time
213     self._pq = collections.deque() # packets
214
215   def _log(self, dflag, msg, **kwargs):
216     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
217
218   def append(self, packet):
219     self._log(DBG.QUEUE, 'append', d=packet)
220     self._pq.append((time.monotonic(), packet))
221
222   def nonempty(self):
223     self._log(DBG.QUEUE, 'nonempty ?')
224     while True:
225       try: (queuetime, packet) = self._pq[0]
226       except IndexError:
227         self._log(DBG.QUEUE, 'nonempty ? empty.')
228         return False
229
230       age = time.monotonic() - queuetime
231       if age > self._max_queue_time:
232         # strip old packets off the front
233         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
234         self._pq.popleft()
235         continue
236
237       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
238       return True
239
240   def process(self, sizequery, moredata, max_batch):
241     # sizequery() should return size of batch so far
242     # moredata(s) should add s to batch
243     self._log(DBG.QUEUE, 'process...')
244     while True:
245       try: (dummy, packet) = self._pq[0]
246       except IndexError:
247         self._log(DBG.QUEUE, 'process... empty')
248         break
249
250       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
251
252       encoded = slip.encode(packet)
253       sofar = sizequery()  
254
255       self._log(DBG.QUEUE_CTRL,
256                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
257                 d=encoded)
258
259       if sofar > 0:
260         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
261           self._log(DBG.QUEUE_CTRL, 'process... overflow')
262           break
263         moredata(slip.delimiter)
264
265       moredata(encoded)
266       self._pq.popleft()
267
268 #---------- error handling ----------
269
270 _crashing = False
271
272 def crash(err):
273   global _crashing
274   _crashing = True
275   print('CRASH ', err, file=sys.stderr)
276   try: reactor.stop()
277   except twisted.internet.error.ReactorNotRunning: pass
278
279 def crash_on_defer(defer):
280   defer.addErrback(lambda err: crash(err))
281
282 def crash_on_critical(event):
283   if event.get('log_level') >= LogLevel.critical:
284     crash(twisted.logger.formatEvent(event))
285
286 #---------- config processing ----------
287
288 def process_cfg_common_always():
289   global mtu
290   c.mtu = cfg.get('virtual','mtu')
291
292 def process_cfg_ipif(section, varmap):
293   for d, s in varmap:
294     try: v = getattr(c, s)
295     except AttributeError: continue
296     setattr(c, d, v)
297
298   #print(repr(c))
299
300   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
301
302 def process_cfg_network():
303   c.network = ipnetwork(cfg.get('virtual','network'))
304   if c.network.num_addresses < 3 + 2:
305     raise ValueError('network needs at least 2^3 addresses')
306
307 def process_cfg_server():
308   try:
309     c.server = cfg.get('virtual','server')
310   except NoOptionError:
311     process_cfg_network()
312     c.server = next(c.network.hosts())
313
314 class ServerAddr():
315   def __init__(self, port, addrspec):
316     self.port = port
317     # also self.addr
318     try:
319       self.addr = ipaddress.IPv4Address(addrspec)
320       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
321       self._inurl = b'%s'
322     except AddressValueError:
323       self.addr = ipaddress.IPv6Address(addrspec)
324       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
325       self._inurl = b'[%s]'
326   def make_endpoint(self):
327     return self._endpointfactory(reactor, self.port, self.addr)
328   def url(self):
329     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
330     if self.port != 80: url += b':%d' % self.port
331     url += b'/'
332     return url
333     
334 def process_cfg_saddrs():
335   try: port = cfg.getint('server','port')
336   except NoOptionError: port = 80
337
338   c.saddrs = [ ]
339   for addrspec in cfg.get('server','addrs').split():
340     sa = ServerAddr(port, addrspec)
341     c.saddrs.append(sa)
342
343 def process_cfg_clients(constructor):
344   c.clients = [ ]
345   for cs in cfg.sections():
346     if not (':' in cs or '.' in cs): continue
347     ci = ipaddr(cs)
348     pw = cfg.get(cs, 'password')
349     pw = pw.encode('utf-8')
350     constructor(ci,cs,pw)
351
352 #---------- startup ----------
353
354 def common_startup():
355   log_formatter = twisted.logger.formatEventAsClassicLogText
356   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
357   twisted.logger.globalLogBeginner.beginLoggingTo(
358     [ log_observer, crash_on_critical ]
359     )
360
361   optparser.add_option('-c', '--config', dest='configfile',
362                        default='/etc/hippotat/config')
363   (opts, args) = optparser.parse_args()
364   if len(args): optparser.error('no non-option arguments please')
365
366   re = regexp.compile('#.*')
367   cfg.read_string(re.sub('', defcfg))
368   cfg.read(opts.configfile)
369
370 def common_run():
371   log_debug(DBG.INIT, 'entering reactor')
372   if not _crashing: reactor.run()
373   print('CRASHED (end)', file=sys.stderr)