chiark / gitweb /
wip, fixes
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
32 class DBG(twisted.python.constants.Names):
33   ROUTE = NamedConstant()
34   DROP = NamedConstant()
35   FLOW = NamedConstant()
36   HTTP = NamedConstant()
37   HTTP_CTRL = NamedConstant()
38   INIT = NamedConstant()
39   QUEUE = NamedConstant()
40   QUEUE_CTRL = NamedConstant()
41   HTTP_FULL = NamedConstant()
42
43 _hex_codec = codecs.getencoder('hex_codec')
44
45 log = twisted.logger.Logger()
46
47 def log_debug(dflag, msg, idof=None, d=None):
48   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
49   if idof is not None:
50     msg = '[%d] %s' % (id(idof), msg)
51   if d is not None:
52     #d = d[0:64]
53     d = _hex_codec(d)[0].decode('ascii')
54     msg += ' ' + d
55   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
56
57 defcfg = '''
58 [DEFAULT]
59 #[<client>] overrides
60 max_batch_down = 65536           # used by server, subject to [limits]
61 max_queue_time = 10              # used by server, subject to [limits]
62 max_request_time = 54            # used by server, subject to [limits]
63 target_requests_outstanding = 3  # must match; subject to [limits] on server
64 max_requests_outstanding = 4     # used by client
65 max_batch_up = 4000              # used by client
66 http_timeout = 30                # used by client
67 http_retry = 5                   # used by client
68
69 #[server] or [<client>] overrides
70 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
71 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
72 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
73 #      from   on client  <client>         [virtual]server   [virtual]routes
74
75 [virtual]
76 mtu = 1500
77 routes = ''
78 # network = <prefix>/<len>  # mandatory for server
79 # server  = <ipaddr>   # used by both, default is computed from `network'
80 # relay   = <ipaddr>   # used by server, default from `network' and `server'
81 #  default server is first host in network
82 #  default relay is first host which is not server
83
84 [server]
85 # addrs = 127.0.0.1 ::1    # mandatory for server
86 port = 80                  # used by server
87 # url              # used by client; default from first `addrs' and `port'
88
89 # [<client-ip4-or-ipv6-address>]
90 # password = <password>    # used by both, must match
91
92 [limits]
93 max_batch_down = 262144           # used by server
94 max_queue_time = 121              # used by server
95 max_request_time = 121            # used by server
96 target_requests_outstanding = 10  # used by server
97 '''
98
99 # these need to be defined here so that they can be imported by import *
100 cfg = ConfigParser()
101 optparser = OptionParser()
102
103 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
104 def mime_translate(s):
105   # SLIP-encoded packets cannot contain ESC ESC.
106   # Swap `-' and ESC.  The result cannot contain `--'
107   return s.translate(_mimetrans)
108
109 class ConfigResults:
110   def __init__(self, d = { }):
111     self.__dict__ = d
112   def __repr__(self):
113     return 'ConfigResults('+repr(self.__dict__)+')'
114
115 c = ConfigResults()
116
117 def log_discard(packet, iface, saddr, daddr, why):
118   log_debug(DBG.DROP,
119             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
120             d=packet)
121
122 #---------- packet parsing ----------
123
124 def packet_addrs(packet):
125   version = packet[0] >> 4
126   if version == 4:
127     addrlen = 4
128     saddroff = 3*4
129     factory = ipaddress.IPv4Address
130   elif version == 6:
131     addrlen = 16
132     saddroff = 2*4
133     factory = ipaddress.IPv6Address
134   else:
135     raise ValueError('unsupported IP version %d' % version)
136   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
137   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
138   return (saddr, daddr)
139
140 #---------- address handling ----------
141
142 def ipaddr(input):
143   try:
144     r = ipaddress.IPv4Address(input)
145   except AddressValueError:
146     r = ipaddress.IPv6Address(input)
147   return r
148
149 def ipnetwork(input):
150   try:
151     r = ipaddress.IPv4Network(input)
152   except NetworkValueError:
153     r = ipaddress.IPv6Network(input)
154   return r
155
156 #---------- ipif (SLIP) subprocess ----------
157
158 class SlipStreamDecoder():
159   def __init__(self, on_packet):
160     # we will call packet(<packet>)
161     self._buffer = b''
162     self._on_packet = on_packet
163
164   def inputdata(self, data):
165     #print('SLIP-GOT ', repr(data))
166     data = self._buffer + data
167     self._buffer = b''
168     packets = slip.decode(data)
169     self._buffer = packets.pop()
170     for packet in packets:
171       self._maybe_packet(packet)
172
173   def _maybe_packet(self, packet):
174       if len(packet):
175         self._on_packet(packet)
176
177   def flush(self):
178     self._maybe_packet(self._buffer)
179     self._buffer = b''
180
181 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
182   def __init__(self, router):
183     self._router = router
184     self._decoder = SlipStreamDecoder(self.slip_on_packet)
185   def connectionMade(self): pass
186   def outReceived(self, data):
187     self._decoder.inputdata(data)
188   def slip_on_packet(self, packet):
189     (saddr, daddr) = packet_addrs(packet)
190     if saddr.is_link_local or daddr.is_link_local:
191       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
192       return
193     self._router(packet, saddr, daddr)
194   def processEnded(self, status):
195     status.raiseException()
196
197 def start_ipif(command, router):
198   global ipif
199   ipif = _IpifProcessProtocol(router)
200   reactor.spawnProcess(ipif,
201                        '/bin/sh',['sh','-xc', command],
202                        childFDs={0:'w', 1:'r', 2:2},
203                        env=None)
204
205 def queue_inbound(packet):
206   log_debug(DBG.FLOW, "queue_inbound", d=packet)
207   ipif.transport.write(slip.delimiter)
208   ipif.transport.write(slip.encode(packet))
209   ipif.transport.write(slip.delimiter)
210
211 #---------- packet queue ----------
212
213 class PacketQueue():
214   def __init__(self, desc, max_queue_time):
215     self._desc = desc
216     assert(desc + '')
217     self._max_queue_time = max_queue_time
218     self._pq = collections.deque() # packets
219
220   def _log(self, dflag, msg, **kwargs):
221     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
222
223   def append(self, packet):
224     self._log(DBG.QUEUE, 'append', d=packet)
225     self._pq.append((time.monotonic(), packet))
226
227   def nonempty(self):
228     self._log(DBG.QUEUE, 'nonempty ?')
229     while True:
230       try: (queuetime, packet) = self._pq[0]
231       except IndexError:
232         self._log(DBG.QUEUE, 'nonempty ? empty.')
233         return False
234
235       age = time.monotonic() - queuetime
236       if age > self._max_queue_time:
237         # strip old packets off the front
238         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
239         self._pq.popleft()
240         continue
241
242       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
243       return True
244
245   def process(self, sizequery, moredata, max_batch):
246     # sizequery() should return size of batch so far
247     # moredata(s) should add s to batch
248     self._log(DBG.QUEUE, 'process...')
249     while True:
250       try: (dummy, packet) = self._pq[0]
251       except IndexError:
252         self._log(DBG.QUEUE, 'process... empty')
253         break
254
255       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
256
257       encoded = slip.encode(packet)
258       sofar = sizequery()  
259
260       self._log(DBG.QUEUE_CTRL,
261                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
262                 d=encoded)
263
264       if sofar > 0:
265         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
266           self._log(DBG.QUEUE_CTRL, 'process... overflow')
267           break
268         moredata(slip.delimiter)
269
270       moredata(encoded)
271       self._pq.popleft()
272
273 #---------- error handling ----------
274
275 _crashing = False
276
277 def crash(err):
278   global _crashing
279   _crashing = True
280   print('CRASH ', err, file=sys.stderr)
281   try: reactor.stop()
282   except twisted.internet.error.ReactorNotRunning: pass
283
284 def crash_on_defer(defer):
285   defer.addErrback(lambda err: crash(err))
286
287 def crash_on_critical(event):
288   if event.get('log_level') >= LogLevel.critical:
289     crash(twisted.logger.formatEvent(event))
290
291 #---------- config processing ----------
292
293 def process_cfg_common_always():
294   global mtu
295   c.mtu = cfg.get('virtual','mtu')
296
297 def process_cfg_ipif(section, varmap):
298   for d, s in varmap:
299     try: v = getattr(c, s)
300     except AttributeError: continue
301     setattr(c, d, v)
302
303   #print(repr(c))
304
305   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
306
307 def process_cfg_network():
308   c.network = ipnetwork(cfg.get('virtual','network'))
309   if c.network.num_addresses < 3 + 2:
310     raise ValueError('network needs at least 2^3 addresses')
311
312 def process_cfg_server():
313   try:
314     c.server = cfg.get('virtual','server')
315   except NoOptionError:
316     process_cfg_network()
317     c.server = next(c.network.hosts())
318
319 class ServerAddr():
320   def __init__(self, port, addrspec):
321     self.port = port
322     # also self.addr
323     try:
324       self.addr = ipaddress.IPv4Address(addrspec)
325       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
326       self._inurl = b'%s'
327     except AddressValueError:
328       self.addr = ipaddress.IPv6Address(addrspec)
329       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
330       self._inurl = b'[%s]'
331   def make_endpoint(self):
332     return self._endpointfactory(reactor, self.port, self.addr)
333   def url(self):
334     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
335     if self.port != 80: url += b':%d' % self.port
336     url += b'/'
337     return url
338     
339 def process_cfg_saddrs():
340   try: port = cfg.getint('server','port')
341   except NoOptionError: port = 80
342
343   c.saddrs = [ ]
344   for addrspec in cfg.get('server','addrs').split():
345     sa = ServerAddr(port, addrspec)
346     c.saddrs.append(sa)
347
348 def process_cfg_clients(constructor):
349   c.clients = [ ]
350   for cs in cfg.sections():
351     if not (':' in cs or '.' in cs): continue
352     ci = ipaddr(cs)
353     pw = cfg.get(cs, 'password')
354     pw = pw.encode('utf-8')
355     constructor(ci,cs,pw)
356
357 #---------- startup ----------
358
359 def common_startup():
360   log_formatter = twisted.logger.formatEventAsClassicLogText
361   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
362   twisted.logger.globalLogBeginner.beginLoggingTo(
363     [ log_observer, crash_on_critical ]
364     )
365
366   optparser.add_option('-c', '--config', dest='configfile',
367                        default='/etc/hippotat/config')
368   (opts, args) = optparser.parse_args()
369   if len(args): optparser.error('no non-option arguments please')
370
371   re = regexp.compile('#.*')
372   cfg.read_string(re.sub('', defcfg))
373   cfg.read(opts.configfile)
374
375 def common_run():
376   log_debug(DBG.INIT, 'entering reactor')
377   if not _crashing: reactor.run()
378   print('CRASHED (end)', file=sys.stderr)