chiark / gitweb /
wip, fixes
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 import twisted.internet.endpoints
11 import twisted.logger
12 from twisted.logger import LogLevel
13 import twisted.python.constants
14 from twisted.python.constants import NamedConstant
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24 import time
25 import codecs
26 import traceback
27
28 import re as regexp
29
30 import hippotat.slip as slip
31
32 class DBG(twisted.python.constants.Names):
33   ROUTE = NamedConstant()
34   DROP = NamedConstant()
35   FLOW = NamedConstant()
36   HTTP = NamedConstant()
37   HTTP_CTRL = NamedConstant()
38   INIT = NamedConstant()
39   QUEUE = NamedConstant()
40   QUEUE_CTRL = NamedConstant()
41   HTTP_FULL = NamedConstant()
42
43 _hex_codec = codecs.getencoder('hex_codec')
44
45 log = twisted.logger.Logger()
46
47 def log_debug(dflag, msg, idof=None, d=None):
48   #print('---------------->',repr((dflag, msg, idof, d)), file=sys.stderr)
49   if idof is not None:
50     msg = '[%d] %s' % (id(idof), msg)
51   if d is not None:
52     #d = d[0:64]
53     d = _hex_codec(d)[0].decode('ascii')
54     msg += ' ' + d
55   log.info('{dflag} {msgcore}', dflag=dflag, msgcore=msg)
56
57 defcfg = '''
58 [DEFAULT]
59 #[<client>] overrides
60 max_batch_down = 65536           # used by server, subject to [limits]
61 max_queue_time = 10              # used by server, subject to [limits]
62 max_request_time = 54            # used by server, subject to [limits]
63 target_requests_outstanding = 3  # must match; subject to [limits] on server
64 max_requests_outstanding = 4     # used by client
65 max_batch_up = 4000              # used by client
66 http_timeout = 30                # used by client
67 http_retry = 5                   # used by client
68
69 #[server] or [<client>] overrides
70 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
71 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
72 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
73 #      from   on client  <client>         [virtual]server   [virtual]routes
74
75 [virtual]
76 mtu = 1500
77 routes = ''
78 # network = <prefix>/<len>  # mandatory for server
79 # server  = <ipaddr>   # used by both, default is computed from `network'
80 # relay   = <ipaddr>   # used by server, default from `network' and `server'
81 #  default server is first host in network
82 #  default relay is first host which is not server
83
84 [server]
85 # addrs = 127.0.0.1 ::1    # mandatory for server
86 port = 80                  # used by server
87 # url              # used by client; default from first `addrs' and `port'
88
89 # [<client-ip4-or-ipv6-address>]
90 # password = <password>    # used by both, must match
91
92 [limits]
93 max_batch_down = 262144           # used by server
94 max_queue_time = 121              # used by server
95 max_request_time = 121            # used by server
96 target_requests_outstanding = 10  # used by server
97 '''
98
99 # these need to be defined here so that they can be imported by import *
100 cfg = ConfigParser()
101 optparser = OptionParser()
102
103 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
104 def mime_translate(s):
105   # SLIP-encoded packets cannot contain ESC ESC.
106   # Swap `-' and ESC.  The result cannot contain `--'
107   return s.translate(_mimetrans)
108
109 class ConfigResults:
110   def __init__(self, d = { }):
111     self.__dict__ = d
112   def __repr__(self):
113     return 'ConfigResults('+repr(self.__dict__)+')'
114
115 c = ConfigResults()
116
117 def log_discard(packet, iface, saddr, daddr, why):
118   log_debug(DBG.DROP,
119             'discarded packet [%s] %s -> %s: %s' % (iface, saddr, daddr, why),
120             d=packet)
121
122 #---------- packet parsing ----------
123
124 def packet_addrs(packet):
125   version = packet[0] >> 4
126   if version == 4:
127     addrlen = 4
128     saddroff = 3*4
129     factory = ipaddress.IPv4Address
130   elif version == 6:
131     addrlen = 16
132     saddroff = 2*4
133     factory = ipaddress.IPv6Address
134   else:
135     raise ValueError('unsupported IP version %d' % version)
136   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
137   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
138   return (saddr, daddr)
139
140 #---------- address handling ----------
141
142 def ipaddr(input):
143   try:
144     r = ipaddress.IPv4Address(input)
145   except AddressValueError:
146     r = ipaddress.IPv6Address(input)
147   return r
148
149 def ipnetwork(input):
150   try:
151     r = ipaddress.IPv4Network(input)
152   except NetworkValueError:
153     r = ipaddress.IPv6Network(input)
154   return r
155
156 #---------- ipif (SLIP) subprocess ----------
157
158 class SlipStreamDecoder():
159   def __init__(self, on_packet):
160     # we will call packet(<packet>)
161     self._buffer = b''
162     self._on_packet = on_packet
163
164   def inputdata(self, data):
165     #print('SLIP-GOT ', repr(data))
166     self._buffer += data
167     packets = slip.decode(self._buffer)
168     self._buffer = packets.pop()
169     for packet in packets:
170       self._maybe_packet(packet)
171
172   def _maybe_packet(self, packet):
173       if len(packet):
174         self._on_packet(packet)
175
176   def flush(self):
177     self._maybe_packet(self._buffer)
178     self._buffer = b''
179
180 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
181   def __init__(self, router):
182     self._router = router
183     self._decoder = SlipStreamDecoder(self.slip_on_packet)
184   def connectionMade(self): pass
185   def outReceived(self, data):
186     self._decoder.inputdata(data)
187   def slip_on_packet(self, packet):
188     (saddr, daddr) = packet_addrs(packet)
189     if saddr.is_link_local or daddr.is_link_local:
190       log_discard(packet, 'ipif', saddr, daddr, 'link-local')
191       return
192     self._router(packet, saddr, daddr)
193   def processEnded(self, status):
194     status.raiseException()
195
196 def start_ipif(command, router):
197   global ipif
198   ipif = _IpifProcessProtocol(router)
199   reactor.spawnProcess(ipif,
200                        '/bin/sh',['sh','-xc', command],
201                        childFDs={0:'w', 1:'r', 2:2},
202                        env=None)
203
204 def queue_inbound(packet):
205   log_debug(DBG.FLOW, "queue_inbound", d=packet)
206   ipif.transport.write(slip.delimiter)
207   ipif.transport.write(slip.encode(packet))
208   ipif.transport.write(slip.delimiter)
209
210 #---------- packet queue ----------
211
212 class PacketQueue():
213   def __init__(self, desc, max_queue_time):
214     self._desc = desc
215     assert(desc + '')
216     self._max_queue_time = max_queue_time
217     self._pq = collections.deque() # packets
218
219   def _log(self, dflag, msg, **kwargs):
220     log_debug(dflag, self._desc+' pq: '+msg, **kwargs)
221
222   def append(self, packet):
223     self._log(DBG.QUEUE, 'append', d=packet)
224     self._pq.append((time.monotonic(), packet))
225
226   def nonempty(self):
227     self._log(DBG.QUEUE, 'nonempty ?')
228     while True:
229       try: (queuetime, packet) = self._pq[0]
230       except IndexError:
231         self._log(DBG.QUEUE, 'nonempty ? empty.')
232         return False
233
234       age = time.monotonic() - queuetime
235       if age > self._max_queue_time:
236         # strip old packets off the front
237         self._log(DBG.QUEUE, 'dropping (old)', d=packet)
238         self._pq.popleft()
239         continue
240
241       self._log(DBG.QUEUE, 'nonempty ? nonempty.')
242       return True
243
244   def process(self, sizequery, moredata, max_batch):
245     # sizequery() should return size of batch so far
246     # moredata(s) should add s to batch
247     self._log(DBG.QUEUE, 'process...')
248     while True:
249       try: (dummy, packet) = self._pq[0]
250       except IndexError:
251         self._log(DBG.QUEUE, 'process... empty')
252         break
253
254       self._log(DBG.QUEUE_CTRL, 'process... packet', d=packet)
255
256       encoded = slip.encode(packet)
257       sofar = sizequery()  
258
259       self._log(DBG.QUEUE_CTRL,
260                 'process... (sofar=%d, max=%d) encoded' % (sofar, max_batch),
261                 d=encoded)
262
263       if sofar > 0:
264         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
265           self._log(DBG.QUEUE_CTRL, 'process... overflow')
266           break
267         moredata(slip.delimiter)
268
269       moredata(encoded)
270       self._pq.popleft()
271
272 #---------- error handling ----------
273
274 _crashing = False
275
276 def crash(err):
277   global _crashing
278   _crashing = True
279   print('CRASH ', err, file=sys.stderr)
280   try: reactor.stop()
281   except twisted.internet.error.ReactorNotRunning: pass
282
283 def crash_on_defer(defer):
284   defer.addErrback(lambda err: crash(err))
285
286 def crash_on_critical(event):
287   if event.get('log_level') >= LogLevel.critical:
288     crash(twisted.logger.formatEvent(event))
289
290 #---------- config processing ----------
291
292 def process_cfg_common_always():
293   global mtu
294   c.mtu = cfg.get('virtual','mtu')
295
296 def process_cfg_ipif(section, varmap):
297   for d, s in varmap:
298     try: v = getattr(c, s)
299     except AttributeError: continue
300     setattr(c, d, v)
301
302   #print(repr(c))
303
304   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
305
306 def process_cfg_network():
307   c.network = ipnetwork(cfg.get('virtual','network'))
308   if c.network.num_addresses < 3 + 2:
309     raise ValueError('network needs at least 2^3 addresses')
310
311 def process_cfg_server():
312   try:
313     c.server = cfg.get('virtual','server')
314   except NoOptionError:
315     process_cfg_network()
316     c.server = next(c.network.hosts())
317
318 class ServerAddr():
319   def __init__(self, port, addrspec):
320     self.port = port
321     # also self.addr
322     try:
323       self.addr = ipaddress.IPv4Address(addrspec)
324       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
325       self._inurl = b'%s'
326     except AddressValueError:
327       self.addr = ipaddress.IPv6Address(addrspec)
328       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
329       self._inurl = b'[%s]'
330   def make_endpoint(self):
331     return self._endpointfactory(reactor, self.port, self.addr)
332   def url(self):
333     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
334     if self.port != 80: url += b':%d' % self.port
335     url += b'/'
336     return url
337     
338 def process_cfg_saddrs():
339   try: port = cfg.getint('server','port')
340   except NoOptionError: port = 80
341
342   c.saddrs = [ ]
343   for addrspec in cfg.get('server','addrs').split():
344     sa = ServerAddr(port, addrspec)
345     c.saddrs.append(sa)
346
347 def process_cfg_clients(constructor):
348   c.clients = [ ]
349   for cs in cfg.sections():
350     if not (':' in cs or '.' in cs): continue
351     ci = ipaddr(cs)
352     pw = cfg.get(cs, 'password')
353     pw = pw.encode('utf-8')
354     constructor(ci,cs,pw)
355
356 #---------- startup ----------
357
358 def common_startup():
359   log_formatter = twisted.logger.formatEventAsClassicLogText
360   log_observer = twisted.logger.FileLogObserver(sys.stderr, log_formatter)
361   twisted.logger.globalLogBeginner.beginLoggingTo(
362     [ log_observer, crash_on_critical ]
363     )
364
365   optparser.add_option('-c', '--config', dest='configfile',
366                        default='/etc/hippotat/config')
367   (opts, args) = optparser.parse_args()
368   if len(args): optparser.error('no non-option arguments please')
369
370   re = regexp.compile('#.*')
371   cfg.read_string(re.sub('', defcfg))
372   cfg.read(opts.configfile)
373
374 def common_run():
375   log_debug(DBG.INIT, 'entering reactor')
376   if not _crashing: reactor.run()
377   print('CRASHED (end)', file=sys.stderr)