chiark / gitweb /
wip
[hippotat.git] / server
1 #!/usr/bin/python3
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 import twisted
10 import twisted.internet
11 import twisted.internet.endpoints
12 from twisted.internet import reactor
13 from twisted.web.server import NOT_DONE_YET
14 from twisted.logger import LogLevel
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 #import twisted.web.server import Site
20 #from twisted.web.resource import Resource
21
22 from optparse import OptionParser
23 from configparser import ConfigParser
24 from configparser import NoOptionError
25
26 import collections
27
28 import syslog
29
30 clients = { }
31
32 def ipaddr(input):
33   try:
34     r = ipaddress.IPv4Address(input)
35   except AddressValueError:
36     r = ipaddress.IPv6Address(input)
37   return r
38
39 def ipnetwork(input):
40   try:
41     r = ipaddress.IPv4Network(input)
42   except NetworkValueError:
43     r = ipaddress.IPv6Network(input)
44   return r
45
46 defcfg = '''
47 [DEFAULT]
48 max_batch_down = 65536
49 max_queue_time = 10
50 max_request_time = 54
51
52 [virtual]
53 mtu = 1500
54 # network
55 # [host]
56 # [relay]
57
58 [server]
59 ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
60 addrs = 127.0.0.1 ::1
61 port = 8099
62
63 [limits]
64 max_batch_down = 262144
65 max_queue_time = 121
66 max_request_time = 121
67 '''
68
69 #---------- error handling ----------
70
71 def crash(err):
72   print('CRASH ', err, file=sys.stderr)
73   try: reactor.stop()
74   except twisted.internet.error.ReactorNotRunning: pass
75
76 def crash_on_defer(defer):
77   defer.addErrback(lambda err: crash(err))
78
79 def crash_on_critical(event):
80   if event.get('log_level') >= LogLevel.critical:
81     crash(twisted.logger.formatEvent(event))
82
83 #---------- "router" ----------
84
85 def route(packet, saddr, daddr):
86   print('TRACE ', saddr, daddr, packet)
87   try: client = clients[daddr]
88   except KeyError: dclient = None
89   if dclient is not None:
90     dclient.queue_outbound(packet)
91   elif saddr.is_link_local or daddr.is_link_local:
92     log_discard(packet, saddr, daddr, 'link-local')
93   elif daddr == host or daddr not in network:
94     print('TRACE INBOUND ', saddr, daddr, packet)
95     queue_inbound(packet)
96   elif daddr == relay:
97     log_discard(packet, saddr, daddr, 'relay')
98   else:
99     log_discard(packet, saddr, daddr, 'no client')
100
101 def log_discard(packet, saddr, daddr, why):
102   print('DROP ', saddr, daddr, why)
103 #  syslog.syslog(syslog.LOG_DEBUG,
104 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
105
106 #---------- ipif (slip subprocess) ----------
107
108 class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
109   def __init__(self):
110     self._buffer = b''
111   def connectionMade(self): pass
112   def outReceived(self, data):
113     #print('RECV ', repr(data))
114     self._buffer += data
115     packets = slip_decode(self._buffer)
116     self._buffer = packets.pop()
117     for packet in packets:
118       if not len(packet): continue
119       (saddr, daddr) = packet_addrs(packet)
120       route(packet, saddr, daddr)
121   def processEnded(self, status):
122     status.raiseException()
123
124 def start_ipif():
125   global ipif
126   ipif = IpifProcessProtocol()
127   reactor.spawnProcess(ipif,
128                        '/bin/sh',['sh','-xc', ipif_command],
129                        childFDs={0:'w', 1:'r', 2:2})
130
131 def queue_inbound(packet):
132   ipif.transport.write(slip_delimiter)
133   ipif.transport.write(slip_encode(packet))
134   ipif.transport.write(slip_delimiter)
135
136 #---------- SLIP handling ----------
137
138 slip_end = b'\300'
139 slip_esc = b'\333'
140 slip_esc_end = b'\334'
141 slip_esc_esc = b'\335'
142 slip_delimiter = slip_end
143
144 def slip_encode(packet):
145   return (packet
146           .replace(slip_esc, slip_esc + slip_esc_esc)
147           .replace(slip_end, slip_esc + slip_esc_end))
148
149 def slip_decode(data):
150   print('DECODE ', repr(data))
151   out = []
152   for packet in data.split(slip_end):
153     pdata = b''
154     while True:
155       eix = packet.find(slip_esc)
156       if eix == -1:
157         pdata += packet
158         break
159       #print('ESC ', repr((pdata, packet, eix)))
160       pdata += packet[0 : eix]
161       ck = packet[eix+1]
162       if   ck == slip_esc_esc: pdata += slip_esc
163       elif ck == slip_esc_end: pdata += slip_end
164       else: raise ValueError('invalid SLIP escape')
165       packet = packet[eix+2 : ]
166     out.append(pdata)
167   print('DECODED ', repr(out))
168   return out
169
170 #---------- packet parsing ----------
171
172 def packet_addrs(packet):
173   version = packet[0] >> 4
174   if version == 4:
175     addrlen = 4
176     saddroff = 3*4
177     factory = ipaddress.IPv4Address
178   elif version == 6:
179     addrlen = 16
180     saddroff = 2*4
181     factory = ipaddress.IPv6Address
182   else:
183     raise ValueError('unsupported IP version %d' % version)
184   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
185   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
186   return (saddr, daddr)
187
188 #---------- client ----------
189
190 class Client():
191   def __init__(self, ip, cs):
192     # instance data members
193     self._ip = ip
194     self._cs = cs
195     self.pw = cfg.get(cs, 'password')
196     self._rq = collections.deque() # requests
197     self._pq = collections.deque() # packets
198     # plus from config:
199     #  .max_batch_down
200     #  .max_queue_time
201     #  .max_request_time
202     for k in ('max_batch_down','max_queue_time','max_request_time'):
203       req = cfg.getint(cs, k)
204       limit = cfg.getint('limits',k)
205       self.__dict__[k] = min(req, limit)
206
207     def process_arriving_data(self, d):
208       for packet in slip_decode(d):
209         (saddr, daddr) = packet_addrs(packet)
210         if saddr != self._ip:
211           raise ValueError('wrong source address %s' % saddr)
212         route(packet, saddr, daddr)
213
214     def _req_cancel(self, request):
215       request.finish()
216
217     def _req_error(self, err, request):
218       self._req_cancel(request)
219
220     def queue_outbound(self, packet):
221       self._pq.append((time.monotonic(), packet))
222
223     def http_request(self, request):
224       request.setHeader('Content-Type','application/octet-stream')
225       reactor.callLater(self.max_request_time, self._req_cancel, request)
226       request.notifyFinish().addErrback(self._req_error, request)
227       self._rq.append(request)
228       self._check_outbound()
229
230     def _check_outbound(self):
231       while True:
232         try: request = self._rq[0]
233         except IndexError: request = None
234         if request and request.finished:
235           self._rq.popleft()
236           continue
237
238         # now request is an unfinished request, or None
239         try: (queuetime, packet) = self._pq[0]
240         except IndexError:
241           # no packets, oh well
242           break
243
244         age = time.monotonic() - queuetime
245         if age > self.max_queue_time:
246           self._pq.popleft()
247           continue
248
249         if request is None:
250           # no request
251           break
252
253         # request, and also some non-expired packets
254         while True:
255           try: (dummy, packet) = self._pq[0]
256           except IndexError: break
257
258           encoded = slip_encode(packet)
259           
260           if request.sentLength > 0:
261             if (request.sentLength + len(slip_delimiter)
262                 + len(encoded) > self.max_batch_down):
263               break
264             request.write(slip_delimiter)
265
266           request.write(encoded)
267           self._pq.popLeft()
268
269         assert(request.sentLength)
270         self._rq.popLeft()
271         request.finish()
272         # round again, looking for more to do
273
274 class IphttpResource(twisted.web.resource.Resource):
275   def render_POST(self, request):
276     # find client, update config, etc.
277     ci = ipaddr(request.args['i'])
278     c = clients[ci]
279     pw = request.args['pw']
280     if pw != c.pw: raise ValueError('bad password')
281
282     # update config
283     for r, w in (('mbd', 'max_batch_down'),
284                  ('mqt', 'max_queue_time'),
285                  ('mrt', 'max_request_time')):
286       try: v = request.args[r]
287       except KeyError: continue
288       v = int(v)
289       c.__dict__[w] = v
290
291     try: d = request.args['d']
292     except KeyError: d = ''
293
294     c.process_arriving_data(d)
295     c.new_request(request)
296
297 def start_http():
298   resource = IphttpResource()
299   sitefactory = twisted.web.server.Site(resource)
300   for addrspec in cfg.get('server','addrs').split():
301     try:
302       addr = ipaddress.IPv4Address(addrspec)
303       endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
304     except AddressValueError:
305       addr = ipaddress.IPv6Address(addrspec)
306       endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
307     ep = endpointfactory(reactor, cfg.getint('server','port'), addr)
308     crash_on_defer(ep.listen(sitefactory))
309
310 #---------- config and setup ----------
311         
312 def process_cfg():
313   global network
314   global host
315   global relay
316   global ipif_command
317
318   network = ipnetwork(cfg.get('virtual','network'))
319   if network.num_addresses < 3 + 2:
320     raise ValueError('network needs at least 2^3 addresses')
321
322   try:
323     host = cfg.get('virtual','host')
324   except NoOptionError:
325     host = next(network.hosts())
326
327   try:
328     relay = cfg.get('virtual','relay')
329   except NoOptionError:
330     for search in network.hosts():
331       if search == host: continue
332       relay = search
333       break
334
335   for cs in cfg.sections():
336     if not (':' in cs or '.' in cs): continue
337     ci = ipaddr(cs)
338     if ci not in network:
339       raise ValueError('client %s not in network' % ci)
340     if ci in clients:
341       raise ValueError('multiple client cfg sections for %s' % ci)
342     clients[ci] = Client(ci, cs)
343
344   global mtu
345   mtu = cfg.get('virtual','mtu')
346
347   iic_vars = { }
348   for k in ('host','relay','mtu','network'):
349     iic_vars[k] = globals()[k]
350
351   ipif_command = cfg.get('server','ipif', vars=iic_vars)
352
353 def startup():
354   global cfg
355
356   op = OptionParser()
357   op.add_option('-c', '--config', dest='configfile',
358                 default='/etc/hippottd/server.conf')
359   global opts
360   (opts, args) = op.parse_args()
361   if len(args): op.error('no non-option arguments please')
362
363   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
364
365   cfg = ConfigParser()
366   cfg.read_string(defcfg)
367   cfg.read(opts.configfile)
368   process_cfg()
369
370   start_ipif()
371   start_http()
372
373 startup()
374 reactor.run()
375 print('CRASHED (end)', file=sys.stderr)