chiark / gitweb /
fixes
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7
8 import twisted
9 from twisted.internet import reactor
10 from twisted.logger import LogLevel
11 import twisted.internet.endpoints
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 import re as regexp
23
24 import hippotat.slip as slip
25
26 defcfg = '''
27 [DEFAULT]
28 #[<client>] overrides
29 max_batch_down = 65536           # used by server, subject to [limits]
30 max_queue_time = 10              # used by server, subject to [limits]
31 max_request_time = 54            # used by server, subject to [limits]
32 target_requests_outstanding = 3  # must match; subject to [limits] on server
33 max_requests_outstanding = 4     # used by client
34 max_batch_up = 4000              # used by client
35 http_timeout = 30                # used by client
36
37 #[server] or [<client>] overrides
38 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
39 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
40 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
41 #      from   on client  <client>         [virtual]server   [virtual]routes
42
43 [virtual]
44 mtu = 1500
45 routes = ''
46 # network = <prefix>/<len>  # mandatory for server
47 # server  = <ipaddr>   # used by both, default is computed from `network'
48 # relay   = <ipaddr>   # used by server, default from `network' and `server'
49 #  default server is first host in network
50 #  default relay is first host which is not server
51
52 [server]
53 # addrs = 127.0.0.1 ::1    # mandatory for server
54 port = 80                  # used by server
55 # url              # used by client; default from first `addrs' and `port'
56
57 # [<client-ip4-or-ipv6-address>]
58 # password = <password>    # used by both, must match
59
60 [limits]
61 max_batch_down = 262144           # used by server
62 max_queue_time = 121              # used by server
63 max_request_time = 121            # used by server
64 target_requests_outstanding = 10  # used by server
65 '''
66
67 # these need to be defined here so that they can be imported by import *
68 cfg = ConfigParser()
69 optparser = OptionParser()
70
71 _mimetrans = bytes.maketrans(b'-'+slip.esc, slip.esc+b'-')
72 def mime_translate(s):
73   # SLIP-encoded packets cannot contain ESC ESC.
74   # Swap `-' and ESC.  The result cannot contain `--'
75   return s.translate(_mimetrans)
76
77 class ConfigResults:
78   def __init__(self, d = { }):
79     self.__dict__ = d
80   def __repr__(self):
81     return 'ConfigResults('+repr(self.__dict__)+')'
82
83 c = ConfigResults()
84
85 def log_discard(packet, saddr, daddr, why):
86   print('DROP ', saddr, daddr, why)
87 #  syslog.syslog(syslog.LOG_DEBUG,
88 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
89
90 #---------- packet parsing ----------
91
92 def packet_addrs(packet):
93   version = packet[0] >> 4
94   if version == 4:
95     addrlen = 4
96     saddroff = 3*4
97     factory = ipaddress.IPv4Address
98   elif version == 6:
99     addrlen = 16
100     saddroff = 2*4
101     factory = ipaddress.IPv6Address
102   else:
103     raise ValueError('unsupported IP version %d' % version)
104   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
105   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
106   return (saddr, daddr)
107
108 #---------- address handling ----------
109
110 def ipaddr(input):
111   try:
112     r = ipaddress.IPv4Address(input)
113   except AddressValueError:
114     r = ipaddress.IPv6Address(input)
115   return r
116
117 def ipnetwork(input):
118   try:
119     r = ipaddress.IPv4Network(input)
120   except NetworkValueError:
121     r = ipaddress.IPv6Network(input)
122   return r
123
124 #---------- ipif (SLIP) subprocess ----------
125
126 class SlipStreamDecoder():
127   def __init__(self, on_packet):
128     # we will call packet(<packet>)
129     self._buffer = b''
130     self._on_packet = on_packet
131
132   def inputdata(self, data):
133     #print('SLIP-GOT ', repr(data))
134     self._buffer += data
135     packets = slip.decode(self._buffer)
136     self._buffer = packets.pop()
137     for packet in packets:
138       self._maybe_packet(packet)
139
140   def _maybe_packet(self, packet):
141       if len(packet):
142         self._on_packet(packet)
143
144   def flush(self):
145     self._maybe_packet(self._buffer)
146     self._buffer = b''
147
148 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
149   def __init__(self, router):
150     self._router = router
151     self._decoder = SlipStreamDecoder(self.slip_on_packet)
152   def connectionMade(self): pass
153   def outReceived(self, data):
154     self._decoder.inputdata(data)
155   def slip_on_packet(self, packet):
156     (saddr, daddr) = packet_addrs(packet)
157     if saddr.is_link_local or daddr.is_link_local:
158       log_discard(packet, saddr, daddr, 'link-local')
159       return
160     self._router(packet, saddr, daddr)
161   def processEnded(self, status):
162     status.raiseException()
163
164 def start_ipif(command, router):
165   global ipif
166   ipif = _IpifProcessProtocol(router)
167   reactor.spawnProcess(ipif,
168                        '/bin/sh',['sh','-xc', command],
169                        childFDs={0:'w', 1:'r', 2:2},
170                        env=None)
171
172 def queue_inbound(packet):
173   ipif.transport.write(slip.delimiter)
174   ipif.transport.write(slip.encode(packet))
175   ipif.transport.write(slip.delimiter)
176
177 #---------- packet queue ----------
178
179 class PacketQueue():
180   def __init__(self, max_queue_time):
181     self._max_queue_time = max_queue_time
182     self._pq = collections.deque() # packets
183
184   def append(self, packet):
185     self._pq.append((time.monotonic(), packet))
186
187   def nonempty(self):
188     while True:
189       try: (queuetime, packet) = self._pq[0]
190       except IndexError: return False
191
192       age = time.monotonic() - queuetime
193       if age > self.max_queue_time:
194         # strip old packets off the front
195         self._pq.popleft()
196         continue
197
198       return True
199
200   def process(self, sizequery, moredata, max_batch):
201     # sizequery() should return size of batch so far
202     # moredata(s) should add s to batch
203     while True:
204       try: (dummy, packet) = self._pq[0]
205       except IndexError: break
206
207       encoded = slip.encode(packet)
208       sofar = sizequery()  
209
210       if sofar > 0:
211         if sofar + len(slip.delimiter) + len(encoded) > max_batch:
212           break
213         moredata(slip.delimiter)
214
215       moredata(encoded)
216       self._pq.popLeft()
217
218 #---------- error handling ----------
219
220 def crash(err):
221   print('CRASH ', err, file=sys.stderr)
222   try: reactor.stop()
223   except twisted.internet.error.ReactorNotRunning: pass
224
225 def crash_on_defer(defer):
226   defer.addErrback(lambda err: crash(err))
227
228 def crash_on_critical(event):
229   if event.get('log_level') >= LogLevel.critical:
230     crash(twisted.logger.formatEvent(event))
231
232 #---------- config processing ----------
233
234 def process_cfg_common_always():
235   global mtu
236   c.mtu = cfg.get('virtual','mtu')
237
238 def process_cfg_ipif(section, varmap):
239   for d, s in varmap:
240     try: v = getattr(c, s)
241     except AttributeError: continue
242     setattr(c, d, v)
243
244   print(repr(c))
245
246   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
247
248 def process_cfg_network():
249   c.network = ipnetwork(cfg.get('virtual','network'))
250   if c.network.num_addresses < 3 + 2:
251     raise ValueError('network needs at least 2^3 addresses')
252
253 def process_cfg_server():
254   try:
255     c.server = cfg.get('virtual','server')
256   except NoOptionError:
257     process_cfg_network()
258     c.server = next(c.network.hosts())
259
260 class ServerAddr():
261   def __init__(self, port, addrspec):
262     self.port = port
263     # also self.addr
264     try:
265       self.addr = ipaddress.IPv4Address(addrspec)
266       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
267       self._inurl = '%s'
268     except AddressValueError:
269       self.addr = ipaddress.IPv6Address(addrspec)
270       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
271       self._inurl = '[%s]'
272   def make_endpoint(self):
273     return self._endpointfactory(reactor, self.port, self.addr)
274   def url(self):
275     url = 'http://' + (self._inurl % self.addr)
276     if self.port != 80: url += ':%d' % self.port
277     url += '/'
278     return url
279     
280 def process_cfg_saddrs():
281   try: port = cfg.getint('server','port')
282   except NoOptionError: port = 80
283
284   c.saddrs = [ ]
285   for addrspec in cfg.get('server','addrs').split():
286     sa = ServerAddr(port, addrspec)
287     c.saddrs.append(sa)
288
289 def process_cfg_clients(constructor):
290   c.clients = [ ]
291   for cs in cfg.sections():
292     if not (':' in cs or '.' in cs): continue
293     ci = ipaddr(cs)
294     pw = cfg.get(cs, 'password')
295     constructor(ci,cs,pw)
296
297 #---------- startup ----------
298
299 def common_startup():
300   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
301
302   optparser.add_option('-c', '--config', dest='configfile',
303                        default='/etc/hippotat/config')
304   (opts, args) = optparser.parse_args()
305   if len(args): optparser.error('no non-option arguments please')
306
307   re = regexp.compile('#.*')
308   cfg.read_string(re.sub('', defcfg))
309   cfg.read(opts.configfile)
310
311 def common_run():
312   reactor.run()
313   print('CRASHED (end)', file=sys.stderr)