chiark / gitweb /
9bfe280e4c8c8d8b5d158a3b382f611034e5cccb
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21 import time
22
23 import re as regexp
24
25 import hippotat.slip as slip
26
27 defcfg = '''
28 [DEFAULT]
29 #[<client>] overrides
30 max_batch_down = 65536           # used by server, subject to [limits]
31 max_queue_time = 10              # used by server, subject to [limits]
32 max_request_time = 54            # used by server, subject to [limits]
33 target_requests_outstanding = 3  # must match; subject to [limits] on server
34 max_requests_outstanding = 4     # used by client
35 max_batch_up = 4000              # used by client
36 http_timeout = 30                # used by client
37 http_retry = 5                   # used by client
38
39 #[server] or [<client>] overrides
40 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
41 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
42 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
43 #      from   on client  <client>         [virtual]server   [virtual]routes
44
45 [virtual]
46 mtu = 1500
47 routes = ''
48 # network = <prefix>/<len>  # mandatory for server
49 # server  = <ipaddr>   # used by both, default is computed from `network'
50 # relay   = <ipaddr>   # used by server, default from `network' and `server'
51 #  default server is first host in network
52 #  default relay is first host which is not server
53
54 [server]
55 # addrs = 127.0.0.1 ::1    # mandatory for server
56 port = 80                  # used by server
57 # url              # used by client; default from first `addrs' and `port'
58
59 # [<client-ip4-or-ipv6-address>]
60 # password = <password>    # used by both, must match
61
62 [limits]
63 max_batch_down = 262144           # used by server
64 max_queue_time = 121              # used by server
65 max_request_time = 121            # used by server
66 target_requests_outstanding = 10  # used by server
67 '''
68
69 # these need to be defined here so that they can be imported by import *
70 cfg = ConfigParser()
71 optparser = OptionParser()
72
73 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
74 def mime_translate(s):
75   # SLIP-encoded packets cannot contain ESC ESC.
76   # Swap `-' and ESC.  The result cannot contain `--'
77   return s.translate(_mimetrans)
78
79 class ConfigResults:
80   def __init__(self, d = { }):
81     self.__dict__ = d
82   def __repr__(self):
83     return 'ConfigResults('+repr(self.__dict__)+')'
84
85 c = ConfigResults()
86
87 def log_discard(packet, saddr, daddr, why):
88   print('DROP ', saddr, daddr, why)
89 #  syslog.syslog(syslog.LOG_DEBUG,
90 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
91
92 #---------- packet parsing ----------
93
94 def packet_addrs(packet):
95   version = packet[0] >> 4
96   if version == 4:
97     addrlen = 4
98     saddroff = 3*4
99     factory = ipaddress.IPv4Address
100   elif version == 6:
101     addrlen = 16
102     saddroff = 2*4
103     factory = ipaddress.IPv6Address
104   else:
105     raise ValueError('unsupported IP version %d' % version)
106   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
107   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
108   return (saddr, daddr)
109
110 #---------- address handling ----------
111
112 def ipaddr(input):
113   try:
114     r = ipaddress.IPv4Address(input)
115   except AddressValueError:
116     r = ipaddress.IPv6Address(input)
117   return r
118
119 def ipnetwork(input):
120   try:
121     r = ipaddress.IPv4Network(input)
122   except NetworkValueError:
123     r = ipaddress.IPv6Network(input)
124   return r
125
126 #---------- ipif (SLIP) subprocess ----------
127
128 class SlipStreamDecoder():
129   def __init__(self, on_packet):
130     # we will call packet(<packet>)
131     self._buffer = b''
132     self._on_packet = on_packet
133
134   def inputdata(self, data):
135     #print('SLIP-GOT ', repr(data))
136     self._buffer += data
137     packets = slip.decode(self._buffer)
138     self._buffer = packets.pop()
139     for packet in packets:
140       self._maybe_packet(packet)
141
142   def _maybe_packet(self, packet):
143       if len(packet):
144         self._on_packet(packet)
145
146   def flush(self):
147     self._maybe_packet(self._buffer)
148     self._buffer = b''
149
150 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
151   def __init__(self, router):
152     self._router = router
153     self._decoder = SlipStreamDecoder(self.slip_on_packet)
154   def connectionMade(self): pass
155   def outReceived(self, data):
156     self._decoder.inputdata(data)
157   def slip_on_packet(self, packet):
158     (saddr, daddr) = packet_addrs(packet)
159     if saddr.is_link_local or daddr.is_link_local:
160       log_discard(packet, saddr, daddr, 'link-local')
161       return
162     self._router(packet, saddr, daddr)
163   def processEnded(self, status):
164     status.raiseException()
165
166 def start_ipif(command, router):
167   global ipif
168   ipif = _IpifProcessProtocol(router)
169   reactor.spawnProcess(ipif,
170                        '/bin/sh',['sh','-xc', command],
171                        childFDs={0:'w', 1:'r', 2:2},
172                        env=None)
173
174 def queue_inbound(packet):
175   ipif.transport.write(slip.delimiter)
176   ipif.transport.write(slip.encode(packet))
177   ipif.transport.write(slip.delimiter)
178
179 #---------- packet queue ----------
180
181 class PacketQueue():
182   def __init__(self, max_queue_time):
183     self._max_queue_time = max_queue_time
184     self._pq = collections.deque() # packets
185
186   def append(self, packet):
187     self._pq.append((time.monotonic(), packet))
188
189   def nonempty(self):
190     while True:
191       try: (queuetime, packet) = self._pq[0]
192       except IndexError: return False
193
194       age = time.monotonic() - queuetime
195       if age > self._max_queue_time:
196         # strip old packets off the front
197         self._pq.popleft()
198         continue
199
200       return True
201
202   def process(self, sizequery, moredata, max_batch):
203     # sizequery() should return size of batch so far
204     # moredata(s) should add s to batch
205     while True:
206       try: (dummy, packet) = self._pq[0]
207       except IndexError: break
208
209       encoded = slip.encode(packet)
210       sofar = sizequery()  
211
212       if sofar > 0:
213         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
214           break
215         moredata(slip.delimiter)
216
217       moredata(encoded)
218       self._pq.popleft()
219
220 #---------- error handling ----------
221
222 def crash(err):
223   print('CRASH ', err, file=sys.stderr)
224   try: reactor.stop()
225   except twisted.internet.error.ReactorNotRunning: pass
226
227 def crash_on_defer(defer):
228   defer.addErrback(lambda err: crash(err))
229
230 def crash_on_critical(event):
231   if event.get('log_level') >= LogLevel.critical:
232     crash(twisted.logger.formatEvent(event))
233
234 #---------- config processing ----------
235
236 def process_cfg_common_always():
237   global mtu
238   c.mtu = cfg.get('virtual','mtu')
239
240 def process_cfg_ipif(section, varmap):
241   for d, s in varmap:
242     try: v = getattr(c, s)
243     except AttributeError: continue
244     setattr(c, d, v)
245
246   print(repr(c))
247
248   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
249
250 def process_cfg_network():
251   c.network = ipnetwork(cfg.get('virtual','network'))
252   if c.network.num_addresses < 3 + 2:
253     raise ValueError('network needs at least 2^3 addresses')
254
255 def process_cfg_server():
256   try:
257     c.server = cfg.get('virtual','server')
258   except NoOptionError:
259     process_cfg_network()
260     c.server = next(c.network.hosts())
261
262 class ServerAddr():
263   def __init__(self, port, addrspec):
264     self.port = port
265     # also self.addr
266     try:
267       self.addr = ipaddress.IPv4Address(addrspec)
268       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
269       self._inurl = b'%s'
270     except AddressValueError:
271       self.addr = ipaddress.IPv6Address(addrspec)
272       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
273       self._inurl = b'[%s]'
274   def make_endpoint(self):
275     return self._endpointfactory(reactor, self.port, self.addr)
276   def url(self):
277     url = b'http://' + (self._inurl % str(self.addr).encode('ascii'))
278     if self.port != 80: url += b':%d' % self.port
279     url += b'/'
280     return url
281     
282 def process_cfg_saddrs():
283   try: port = cfg.getint('server','port')
284   except NoOptionError: port = 80
285
286   c.saddrs = [ ]
287   for addrspec in cfg.get('server','addrs').split():
288     sa = ServerAddr(port, addrspec)
289     c.saddrs.append(sa)
290
291 def process_cfg_clients(constructor):
292   c.clients = [ ]
293   for cs in cfg.sections():
294     if not (':' in cs or '.' in cs): continue
295     ci = ipaddr(cs)
296     pw = cfg.get(cs, 'password')
297     pw = pw.encode('utf-8')
298     constructor(ci,cs,pw)
299
300 #---------- startup ----------
301
302 def common_startup():
303   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
304
305   optparser.add_option('-c', '--config', dest='configfile',
306                        default='/etc/hippotat/config')
307   (opts, args) = optparser.parse_args()
308   if len(args): optparser.error('no non-option arguments please')
309
310   re = regexp.compile('#.*')
311   cfg.read_string(re.sub('', defcfg))
312   cfg.read(opts.configfile)
313
314 def common_run():
315   reactor.run()
316   print('CRASHED (end)', file=sys.stderr)