chiark / gitweb /
fixes
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26
27 import re as regexp
28
29 import hippotat.slip as slip
30
31 class DBG(twisted.python.constants.Names):
32   ROUTE = NamedConstant()
33   DROP = NamedConstant()
34   FLOW = NamedConstant()
35   HTTP = NamedConstant()
36   HTTP_CTRL = NamedConstant()
37   INIT = NamedConstant()
38   QUEUE = NamedConstant()
39   QUEUE_CTRL = NamedConstant()
40   HTTP_FULL = NamedConstant()
41
42 _hex_codec = codecs.getencoder('hex_codec')
43
44 log = twisted.logger.Logger()
45
46 def log_debug(dflag, msg, idof=None, d=None):
47   if idof is not None:
48     msg = '[%d] %s' % (id(idof), msg)
49   if d is not None:
50     d = d[0:64]
51     d = _hex_codec(d)[0].decode('ascii')
52     msg += ' ' + d
53   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
54
55 defcfg = '''
56 [DEFAULT]
57 #[<client>] overrides
58 max_batch_down = 65536           # used by server, subject to [limits]
59 max_queue_time = 10              # used by server, subject to [limits]
60 max_request_time = 54            # used by server, subject to [limits]
61 target_requests_outstanding = 3  # must match; subject to [limits] on server
62 max_requests_outstanding = 4     # used by client
63 max_batch_up = 4000              # used by client
64 http_timeout = 30                # used by client
65 http_retry = 5                   # used by client
66
67 #[server] or [<client>] overrides
68 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
69 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
70 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
71 #      from   on client  <client>         [virtual]server   [virtual]routes
72
73 [virtual]
74 mtu = 1500
75 routes = ''
76 # network = <prefix>/<len>  # mandatory for server
77 # server  = <ipaddr>   # used by both, default is computed from `network'
78 # relay   = <ipaddr>   # used by server, default from `network' and `server'
79 #  default server is first host in network
80 #  default relay is first host which is not server
81
82 [server]
83 # addrs = 127.0.0.1 ::1    # mandatory for server
84 port = 80                  # used by server
85 # url              # used by client; default from first `addrs' and `port'
86
87 # [<client-ip4-or-ipv6-address>]
88 # password = <password>    # used by both, must match
89
90 [limits]
91 max_batch_down = 262144           # used by server
92 max_queue_time = 121              # used by server
93 max_request_time = 121            # used by server
94 target_requests_outstanding = 10  # used by server
95 '''
96
97 # these need to be defined here so that they can be imported by import *
98 cfg = ConfigParser()
99 optparser = OptionParser()
100
101 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
102 def mime_translate(s):
103   # SLIP-encoded packets cannot contain ESC ESC.
104   # Swap `-' and ESC.  The result cannot contain `--'
105   return s.translate(_mimetrans)
106
107 class ConfigResults:
108   def __init__(self, d = { }):
109     self.__dict__ = d
110   def __repr__(self):
111     return 'ConfigResults('+repr(self.__dict__)+')'
112
113 c = ConfigResults()
114
115 def log_discard(packet, iface, saddr, daddr, why):
116   log_debug(DBG.DROP,
117             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
118             d=packet)
119
120 #---------- packet parsing ----------
121
122 def packet_addrs(packet):
123   version = packet[0] >> 4
124   if version == 4:
125     addrlen = 4
126     saddroff = 3*4
127     factory = ipaddress.IPv4Address
128   elif version == 6:
129     addrlen = 16
130     saddroff = 2*4
131     factory = ipaddress.IPv6Address
132   else:
133     raise ValueError('unsupported IP version %d' % version)
134   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
135   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
136   return (saddr, daddr)
137
138 #---------- address handling ----------
139
140 def ipaddr(input):
141   try:
142     r = ipaddress.IPv4Address(input)
143   except AddressValueError:
144     r = ipaddress.IPv6Address(input)
145   return r
146
147 def ipnetwork(input):
148   try:
149     r = ipaddress.IPv4Network(input)
150   except NetworkValueError:
151     r = ipaddress.IPv6Network(input)
152   return r
153
154 #---------- ipif (SLIP) subprocess ----------
155
156 class SlipStreamDecoder():
157   def __init__(self, on_packet):
158     # we will call packet(<packet>)
159     self._buffer = b''
160     self._on_packet = on_packet
161
162   def inputdata(self, data):
163     #print('SLIP-GOT ', repr(data))
164     self._buffer += data
165     packets = slip.decode(self._buffer)
166     self._buffer = packets.pop()
167     for packet in packets:
168       self._maybe_packet(packet)
169
170   def _maybe_packet(self, packet):
171       if len(packet):
172         self._on_packet(packet)
173
174   def flush(self):
175     self._maybe_packet(self._buffer)
176     self._buffer = b''
177
178 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
179   def __init__(self, router):
180     self._router = router
181     self._decoder = SlipStreamDecoder(self.slip_on_packet)
182   def connectionMade(self): pass
183   def outReceived(self, data):
184     self._decoder.inputdata(data)
185   def slip_on_packet(self, packet):
186     (saddr, daddr) = packet_addrs(packet)
187     if saddr.is_link_local or daddr.is_link_local:
188       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
189       return
190     self._router(packet, saddr, daddr)
191   def processEnded(self, status):
192     status.raiseException()
193
194 def start_ipif(command, router):
195   global ipif
196   ipif = _IpifProcessProtocol(router)
197   reactor.spawnProcess(ipif,
198                        '/bin/sh',['sh','-xc', command],
199                        childFDs={0:'w', 1:'r', 2:2},
200                        env=None)
201
202 def queue_inbound(packet):
203   ipif.transport.write(slip.delimiter)
204   ipif.transport.write(slip.encode(packet))
205   ipif.transport.write(slip.delimiter)
206
207 #---------- packet queue ----------
208
209 class PacketQueue():
210   def __init__(self, desc, max_queue_time):
211     self._desc = desc
212     assert(desc + '')
213     self._max_queue_time = max_queue_time
214     self._pq = collections.deque() # packets
215
216   def _log(self, dflag, msg, **kwargs):
217     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
218
219   def append(self, packet):
220     self._log(DBG.QUEUE, 'append', d=packet)
221     self._pq.append((time.monotonic(), packet))
222
223   def nonempty(self):
224     self._log(DBG.QUEUE, 'nonempty ?')
225     while True:
226       try: (queuetime, packet) = self._pq[0]
227       except IndexError:
228         self._log(DBG.QUEUE, 'nonempty ? empty.')
229         return False
230
231       age = time.monotonic() - queuetime
232       if age > self._max_queue_time:
233         # strip old packets off the front
234         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
235         self._pq.popleft()
236         continue
237
238       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
239       return True
240
241   def process(self, sizequery, moredata, max_batch):
242     # sizequery() should return size of batch so far
243     # moredata(s) should add s to batch
244     self._log(DBG.QUEUE, 'process...')
245     while True:
246       try: (dummy, packet) = self._pq[0]
247       except IndexError:
248         self._log(DBG.QUEUE, 'process... empty')
249         break
250
251       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
252
253       encoded = slip.encode(packet)
254       sofar = sizequery()  
255
256       self._log(DBG.QUEUE_CTRL,
257                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
258                 d=encoded)
259
260       if sofar > 0:
261         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
262           self._log(DBG.QUEUE_CTRL, 'process... overflow')
263           break
264         moredata(slip.delimiter)
265
266       moredata(encoded)
267       self._pq.popleft()
268
269 #---------- error handling ----------
270
271 _crashing = False
272
273 def crash(err):
274   global _crashing
275   _crashing = True
276   print('CRASH ', err, file=sys.stderr)
277   try: reactor.stop()
278   except twisted.internet.error.ReactorNotRunning: pass
279
280 def crash_on_defer(defer):
281   defer.addErrback(lambda err: crash(err))
282
283 def crash_on_critical(event):
284   if event.get('log_level') >= LogLevel.critical:
285     crash(twisted.logger.formatEvent(event))
286
287 #---------- config processing ----------
288
289 def process_cfg_common_always():
290   global mtu
291   c.mtu = cfg.get('virtual','mtu')
292
293 def process_cfg_ipif(section, varmap):
294   for d, s in varmap:
295     try: v = getattr(c, s)
296     except AttributeError: continue
297     setattr(c, d, v)
298
299   #print(repr(c))
300
301   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
302
303 def process_cfg_network():
304   c.network = ipnetwork(cfg.get('virtual','network'))
305   if c.network.num_addresses < 3 + 2:
306     raise ValueError('network needs at least 2^3 addresses')
307
308 def process_cfg_server():
309   try:
310     c.server = cfg.get('virtual','server')
311   except NoOptionError:
312     process_cfg_network()
313     c.server = next(c.network.hosts())
314
315 class ServerAddr():
316   def __init__(self, port, addrspec):
317     self.port = port
318     # also self.addr
319     try:
320       self.addr = ipaddress.IPv4Address(addrspec)
321       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
322       self._inurl = b'%s'
323     except AddressValueError:
324       self.addr = ipaddress.IPv6Address(addrspec)
325       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
326       self._inurl = b'[%s]'
327   def make_endpoint(self):
328     return self._endpointfactory(reactor, self.port, self.addr)
329   def url(self):
330     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
331     if self.port != 80: url += b':%d' % self.port
332     url += b'/'
333     return url
334     
335 def process_cfg_saddrs():
336   try: port = cfg.getint('server','port')
337   except NoOptionError: port = 80
338
339   c.saddrs = [ ]
340   for addrspec in cfg.get('server','addrs').split():
341     sa = ServerAddr(port, addrspec)
342     c.saddrs.append(sa)
343
344 def process_cfg_clients(constructor):
345   c.clients = [ ]
346   for cs in cfg.sections():
347     if not (':' in cs or '.' in cs): continue
348     ci = ipaddr(cs)
349     pw = cfg.get(cs, 'password')
350     pw = pw.encode('utf-8')
351     constructor(ci,cs,pw)
352
353 #---------- startup ----------
354
355 def common_startup():
356   log_formatter = twisted.logger.formatEventAsClassicLogText
357   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
358   twisted.logger.globalLogBeginner.beginLoggingTo(
359     [ log_observer, crash_on_critical ]
360     )
361
362   optparser.add_option('-c', '--config', dest='configfile',
363                        default='/etc/hippotat/config')
364   (opts, args) = optparser.parse_args()
365   if len(args): optparser.error('no non-option arguments please')
366
367   re = regexp.compile('#.*')
368   cfg.read_string(re.sub('', defcfg))
369   cfg.read(opts.configfile)
370
371 def common_run():
372   log_debug(DBG.INIT, 'entering reactor')
373   if not _crashing: reactor.run()
374   print('CRASHED (end)', file=sys.stderr)