chiark / gitweb /
wip
[hippotat.git] / server
1 #!/usr/bin/python3
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import sys
7 import os
8
9 import twisted
10 import twisted.internet
11 import twisted.internet.endpoints
12 from twisted.internet import reactor
13 from twisted.web.server import NOT_DONE_YET
14 from twisted.logger import LogLevel
15
16 import ipaddress
17 from ipaddress import AddressValueError
18
19 #import twisted.web.server import Site
20 #from twisted.web.resource import Resource
21
22 from optparse import OptionParser
23 from configparser import ConfigParser
24 from configparser import NoOptionError
25
26 import collections
27
28 import syslog
29
30 clients = { }
31
32 def ipaddr(input):
33   try:
34     r = ipaddress.IPv4Address(input)
35   except AddressValueError:
36     r = ipaddress.IPv6Address(input)
37   return r
38
39 def ipnetwork(input):
40   try:
41     r = ipaddress.IPv4Network(input)
42   except NetworkValueError:
43     r = ipaddress.IPv6Network(input)
44   return r
45
46 defcfg = '''
47 [DEFAULT]
48 max_batch_down = 65536
49 max_queue_time = 10
50 max_request_time = 54
51
52 [virtual]
53 mtu = 1500
54 # network
55 # [host]
56 # [relay]
57
58 [server]
59 ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
60 addrs = 127.0.0.1 ::1
61 port = 8099
62
63 [limits]
64 max_batch_down = 262144
65 max_queue_time = 121
66 max_request_time = 121
67 '''
68
69 #---------- error handling ----------
70
71 def crash(err):
72   print('CRASH ', err, file=sys.stderr)
73   try: reactor.stop()
74   except twisted.internet.error.ReactorNotRunning: pass
75
76 def crash_on_defer(defer):
77   defer.addErrback(lambda err: crash(err))
78
79 def crash_on_critical(event):
80   if event.get('log_level') >= LogLevel.critical:
81     crash(twisted.logger.formatEvent(event))
82
83 #---------- "router" ----------
84
85 def route(packet, saddr, daddr):
86   print('TRACE ', saddr, daddr, packet)
87   try: client = clients[daddr]
88   except KeyError: dclient = None
89   if dclient is not None:
90     dclient.queue_outbound(packet)
91   elif saddr.is_link_local or daddr.is_link_local:
92     log_discard(packet, saddr, daddr, 'link-local')
93   elif daddr == host or daddr not in network:
94     print('TRACE INBOUND ', saddr, daddr, packet)
95     queue_inbound(packet)
96   elif daddr == relay:
97     log_discard(packet, saddr, daddr, 'relay')
98   else:
99     log_discard(packet, saddr, daddr, 'no client')
100
101 def log_discard(packet, saddr, daddr, why):
102   print('DROP ', saddr, daddr, why)
103 #  syslog.syslog(syslog.LOG_DEBUG,
104 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
105
106 #---------- ipif (slip subprocess) ----------
107
108 class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
109   def __init__(self):
110     self._buffer = b''
111   def connectionMade(self): pass
112   def outReceived(self, data):
113     #print('RECV ', repr(data))
114     self._buffer += data
115     packets = slip_decode(self._buffer)
116     self._buffer = packets.pop()
117     for packet in packets:
118       if not len(packet): continue
119       (saddr, daddr) = packet_addrs(packet)
120       route(packet, saddr, daddr)
121   def processEnded(self, status):
122     status.raiseException()
123
124 def start_ipif():
125   global ipif
126   ipif = IpifProcessProtocol()
127   reactor.spawnProcess(ipif,
128                        '/bin/sh',['sh','-xc', ipif_command],
129                        childFDs={0:'w', 1:'r', 2:2})
130
131 def queue_inbound(packet):
132   ipif.transport.write(slip_delimiter)
133   ipif.transport.write(slip_encode(packet))
134   ipif.transport.write(slip_delimiter)
135
136 #---------- SLIP handling ----------
137
138 slip_end = b'\300'
139 slip_esc = b'\333'
140 slip_esc_end = b'\334'
141 slip_esc_esc = b'\335'
142 slip_delimiter = slip_end
143
144 def slip_encode(packet):
145   return (packet
146           .replace(slip_esc, slip_esc + slip_esc_esc)
147           .replace(slip_end, slip_esc + slip_esc_end))
148
149 def slip_decode(data):
150   print('DECODE ', repr(data))
151   out = []
152   for packet in data.split(slip_end):
153     pdata = b''
154     while True:
155       eix = packet.find(slip_esc)
156       if eix == -1:
157         pdata += packet
158         break
159       #print('ESC ', repr((pdata, packet, eix)))
160       pdata += packet[0 : eix]
161       ck = packet[eix+1]
162       #print('ESC... %o' % ck)
163       if   ck == slip_esc_esc[0]: pdata += slip_esc
164       elif ck == slip_esc_end[0]: pdata += slip_end
165       else: raise ValueError('invalid SLIP escape')
166       packet = packet[eix+2 : ]
167     out.append(pdata)
168   print('DECODED ', repr(out))
169   return out
170
171 #---------- packet parsing ----------
172
173 def packet_addrs(packet):
174   version = packet[0] >> 4
175   if version == 4:
176     addrlen = 4
177     saddroff = 3*4
178     factory = ipaddress.IPv4Address
179   elif version == 6:
180     addrlen = 16
181     saddroff = 2*4
182     factory = ipaddress.IPv6Address
183   else:
184     raise ValueError('unsupported IP version %d' % version)
185   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
186   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
187   return (saddr, daddr)
188
189 #---------- client ----------
190
191 class Client():
192   def __init__(self, ip, cs):
193     # instance data members
194     self._ip = ip
195     self._cs = cs
196     self.pw = cfg.get(cs, 'password')
197     self._rq = collections.deque() # requests
198     self._pq = collections.deque() # packets
199     # plus from config:
200     #  .max_batch_down
201     #  .max_queue_time
202     #  .max_request_time
203     for k in ('max_batch_down','max_queue_time','max_request_time'):
204       req = cfg.getint(cs, k)
205       limit = cfg.getint('limits',k)
206       self.__dict__[k] = min(req, limit)
207
208     def process_arriving_data(self, d):
209       for packet in slip_decode(d):
210         (saddr, daddr) = packet_addrs(packet)
211         if saddr != self._ip:
212           raise ValueError('wrong source address %s' % saddr)
213         route(packet, saddr, daddr)
214
215     def _req_cancel(self, request):
216       request.finish()
217
218     def _req_error(self, err, request):
219       self._req_cancel(request)
220
221     def queue_outbound(self, packet):
222       self._pq.append((time.monotonic(), packet))
223
224     def http_request(self, request):
225       request.setHeader('Content-Type','application/octet-stream')
226       reactor.callLater(self.max_request_time, self._req_cancel, request)
227       request.notifyFinish().addErrback(self._req_error, request)
228       self._rq.append(request)
229       self._check_outbound()
230
231     def _check_outbound(self):
232       while True:
233         try: request = self._rq[0]
234         except IndexError: request = None
235         if request and request.finished:
236           self._rq.popleft()
237           continue
238
239         # now request is an unfinished request, or None
240         try: (queuetime, packet) = self._pq[0]
241         except IndexError:
242           # no packets, oh well
243           break
244
245         age = time.monotonic() - queuetime
246         if age > self.max_queue_time:
247           self._pq.popleft()
248           continue
249
250         if request is None:
251           # no request
252           break
253
254         # request, and also some non-expired packets
255         while True:
256           try: (dummy, packet) = self._pq[0]
257           except IndexError: break
258
259           encoded = slip_encode(packet)
260           
261           if request.sentLength > 0:
262             if (request.sentLength + len(slip_delimiter)
263                 + len(encoded) > self.max_batch_down):
264               break
265             request.write(slip_delimiter)
266
267           request.write(encoded)
268           self._pq.popLeft()
269
270         assert(request.sentLength)
271         self._rq.popLeft()
272         request.finish()
273         # round again, looking for more to do
274
275 class IphttpResource(twisted.web.resource.Resource):
276   isLeaf = True
277   def render_POST(self, request):
278     # find client, update config, etc.
279     ci = ipaddr(request.args['i'])
280     c = clients[ci]
281     pw = request.args['pw']
282     if pw != c.pw: raise ValueError('bad password')
283
284     # update config
285     for r, w in (('mbd', 'max_batch_down'),
286                  ('mqt', 'max_queue_time'),
287                  ('mrt', 'max_request_time')):
288       try: v = request.args[r]
289       except KeyError: continue
290       v = int(v)
291       c.__dict__[w] = v
292
293     try: d = request.args['d']
294     except KeyError: d = ''
295
296     c.process_arriving_data(d)
297     c.new_request(request)
298
299   def render_GET(self, request):
300     return b'<html><body>hippotit</body></html>'
301
302 def start_http():
303   resource = IphttpResource()
304   site = twisted.web.server.Site(resource)
305   for addrspec in cfg.get('server','addrs').split():
306     try:
307       addr = ipaddress.IPv4Address(addrspec)
308       endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
309     except AddressValueError:
310       addr = ipaddress.IPv6Address(addrspec)
311       endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
312     ep = endpointfactory(reactor, cfg.getint('server','port'), addr)
313     crash_on_defer(ep.listen(site))
314
315 #---------- config and setup ----------
316         
317 def process_cfg():
318   global network
319   global host
320   global relay
321   global ipif_command
322
323   network = ipnetwork(cfg.get('virtual','network'))
324   if network.num_addresses < 3 + 2:
325     raise ValueError('network needs at least 2^3 addresses')
326
327   try:
328     host = cfg.get('virtual','host')
329   except NoOptionError:
330     host = next(network.hosts())
331
332   try:
333     relay = cfg.get('virtual','relay')
334   except NoOptionError:
335     for search in network.hosts():
336       if search == host: continue
337       relay = search
338       break
339
340   for cs in cfg.sections():
341     if not (':' in cs or '.' in cs): continue
342     ci = ipaddr(cs)
343     if ci not in network:
344       raise ValueError('client %s not in network' % ci)
345     if ci in clients:
346       raise ValueError('multiple client cfg sections for %s' % ci)
347     clients[ci] = Client(ci, cs)
348
349   global mtu
350   mtu = cfg.get('virtual','mtu')
351
352   iic_vars = { }
353   for k in ('host','relay','mtu','network'):
354     iic_vars[k] = globals()[k]
355
356   ipif_command = cfg.get('server','ipif', vars=iic_vars)
357
358 def startup():
359   global cfg
360
361   op = OptionParser()
362   op.add_option('-c', '--config', dest='configfile',
363                 default='/etc/hippottd/server.conf')
364   global opts
365   (opts, args) = op.parse_args()
366   if len(args): op.error('no non-option arguments please')
367
368   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
369
370   cfg = ConfigParser()
371   cfg.read_string(defcfg)
372   cfg.read(opts.configfile)
373   process_cfg()
374
375   start_ipif()
376   start_http()
377
378 startup()
379 reactor.run()
380 print('CRASHED (end)', file=sys.stderr)