chiark / gitweb /
wip
[hippotat.git] / server
1 #!/usr/bin/python3
2
3 import sys
4 import os
5
6 import twisted
7 import twisted.internet
8 import twisted.internet.endpoints
9 from twisted.internet import reactor
10 from twisted.web.server import NOT_DONE_YET
11 from twisted.logger import LogLevel
12
13 import ipaddress
14 from ipaddress import AddressValueError
15
16 #import twisted.web.server import Site
17 #from twisted.web.resource import Resource
18
19 from optparse import OptionParser
20 from configparser import ConfigParser
21 from configparser import NoOptionError
22
23 import collections
24
25 import syslog
26
27 clients = { }
28
29 def ipaddr(input):
30   try:
31     r = ipaddress.IPv4Address(input)
32   except AddressValueError:
33     r = ipaddress.IPv6Address(input)
34   return r
35
36 def ipnetwork(input):
37   try:
38     r = ipaddress.IPv4Network(input)
39   except NetworkValueError:
40     r = ipaddress.IPv6Network(input)
41   return r
42
43 defcfg = '''
44 [DEFAULT]
45 max_batch_down = 65536
46 max_queue_time = 10
47 max_request_time = 54
48
49 [virtual]
50 mtu = 1500
51 # network
52 # [host]
53 # [relay]
54
55 [server]
56 ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
57 addrs = 127.0.0.1 ::1
58 port = 80
59
60 [limits]
61 max_batch_down = 262144
62 max_queue_time = 121
63 max_request_time = 121
64 '''
65
66 #---------- "router" ----------
67
68 def route(packet, saddr, daddr):
69   print('TRACE ', saddr, daddr, packet)
70   try: client = clients[daddr]
71   except KeyError: dclient = None
72   if dclient is not None:
73     dclient.queue_outbound(packet)
74   elif daddr.is_multicast:
75     log_discard(packet, saddr, daddr, 'multicast')
76   elif daddr.is_link_local:
77     log_discard(packet, saddr, daddr, 'link-local')
78   elif daddr == host or daddr not in network:
79     print('TRACE INBOUND ', saddr, daddr, packet)
80     queue_inbound(packet)
81   elif daddr == relay:
82     log_discard(packet, saddr, daddr, 'relay')
83   else:
84     log_discard(packet, saddr, daddr, 'no client')
85
86 def log_discard(packet, saddr, daddr, why):
87   print('DROP ', saddr, daddr, why, packet)
88 #  syslog.syslog(syslog.LOG_DEBUG,
89 #                'discarded packet %s -> %s (%s)' % (saddr, daddr, why))
90
91 #---------- ipif (slip subprocess) ----------
92
93 class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
94   def __init__(self):
95     self._buffer = b''
96   def connectionMade(self): pass
97   def outReceived(self, data):
98     #print('RECV ', repr(data))
99     self._buffer += data
100     packets = slip_decode(self._buffer)
101     self._buffer = packets.pop()
102     for packet in packets:
103       if not len(packet): continue
104       (saddr, daddr) = packet_addrs(packet)
105       route(packet, saddr, daddr)
106   def processEnded(self, status):
107     status.raiseException()
108
109 def start_ipif():
110   global ipif
111   ipif = IpifProcessProtocol()
112   reactor.spawnProcess(ipif,
113                        '/bin/sh',['sh','-xc', ipif_command],
114                        childFDs={0:'w', 1:'r', 2:2})
115
116 def queue_inbound(packet):
117   ipif.transport.write(slip_delimiter)
118   ipif.transport.write(slip_encode(packet))
119   ipif.transport.write(slip_delimiter)
120
121 #---------- SLIP handling ----------
122
123 slip_end = b'\300'
124 slip_esc = b'\333'
125 slip_esc_end = b'\334'
126 slip_esc_esc = b'\335'
127 slip_delimiter = slip_end
128
129 def slip_encode(packet):
130   return (packet
131           .replace(slip_esc, slip_esc + slip_esc_esc)
132           .replace(slip_end, slip_esc + slip_esc_end))
133
134 def slip_decode(data):
135   print('DECODE ', repr(data))
136   out = []
137   for packet in data.split(slip_end):
138     pdata = b''
139     while True:
140       eix = packet.find(slip_esc)
141       if eix == -1:
142         pdata += packet
143         break
144       #print('ESC ', repr((pdata, packet, eix)))
145       pdata += packet[0 : eix]
146       ck = packet[eix+1]
147       if   ck == slip_esc_esc: pdata += slip_esc
148       elif ck == slip_esc_end: pdata += slip_end
149       else: raise ValueError('invalid SLIP escape')
150       packet = packet[eix+2 : ]
151     out.append(pdata)
152   print('DECODED ', repr(out))
153   return out
154
155 #---------- packet parsing ----------
156
157 def packet_addrs(packet):
158   version = packet[0] >> 4
159   if version == 4:
160     addrlen = 4
161     saddroff = 3*4
162     factory = ipaddress.IPv4Address
163   elif version == 6:
164     addrlen = 16
165     saddroff = 2*4
166     factory = ipaddress.IPv6Address
167   else:
168     raise ValueError('unsupported IP version %d' % version)
169   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
170   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
171   return (saddr, daddr)
172
173 #---------- client ----------
174
175 class Client():
176   def __init__(self, ip, cs):
177     # instance data members
178     self._ip = ip
179     self._cs = cs
180     self.pw = cfg.get(cs, 'password')
181     self._rq = collections.deque() # requests
182     self._pq = collections.deque() # packets
183     # plus from config:
184     #  .max_batch_down
185     #  .max_queue_time
186     #  .max_request_time
187     for k in ('max_batch_down','max_queue_time','max_request_time'):
188       req = cfg.getint(cs, k)
189       limit = cfg.getint('limits',k)
190       self.__dict__[k] = min(req, limit)
191
192     def process_arriving_data(self, d):
193       for packet in slip_decode(d):
194         (saddr, daddr) = packet_addrs(packet)
195         if saddr != self._ip:
196           raise ValueError('wrong source address %s' % saddr)
197         route(packet, saddr, daddr)
198
199     def _req_cancel(self, request):
200       request.finish()
201
202     def _req_error(self, err, request):
203       self._req_cancel(request)
204
205     def queue_outbound(self, packet):
206       self._pq.append((time.monotonic(), packet))
207
208     def http_request(self, request):
209       request.setHeader('Content-Type','application/octet-stream')
210       reactor.callLater(self.max_request_time, self._req_cancel, request)
211       request.notifyFinish().addErrback(self._req_error, request)
212       self._rq.append(request)
213       self._check_outbound()
214
215     def _check_outbound(self):
216       while True:
217         try: request = self._rq[0]
218         except IndexError: request = None
219         if request and request.finished:
220           self._rq.popleft()
221           continue
222
223         # now request is an unfinished request, or None
224         try: (queuetime, packet) = self._pq[0]
225         except IndexError:
226           # no packets, oh well
227           break
228
229         age = time.monotonic() - queuetime
230         if age > self.max_queue_time:
231           self._pq.popleft()
232           continue
233
234         if request is None:
235           # no request
236           break
237
238         # request, and also some non-expired packets
239         while True:
240           try: (dummy, packet) = self._pq[0]
241           except IndexError: break
242
243           encoded = slip_encode(packet)
244           
245           if request.sentLength > 0:
246             if (request.sentLength + len(slip_delimiter)
247                 + len(encoded) > self.max_batch_down):
248               break
249             request.write(slip_delimiter)
250
251           request.write(encoded)
252           self._pq.popLeft()
253
254         assert(request.sentLength)
255         self._rq.popLeft()
256         request.finish()
257         # round again, looking for more to do
258
259 class IphttpResource(twisted.web.resource.Resource):
260   def render_POST(self, request):
261     # find client, update config, etc.
262     ci = ipaddr(request.args['i'])
263     c = clients[ci]
264     pw = request.args['pw']
265     if pw != c.pw: raise ValueError('bad password')
266
267     # update config
268     for r, w in (('mbd', 'max_batch_down'),
269                  ('mqt', 'max_queue_time'),
270                  ('mrt', 'max_request_time')):
271       try: v = request.args[r]
272       except KeyError: continue
273       v = int(v)
274       c.__dict__[w] = v
275
276     try: d = request.args['d']
277     except KeyError: d = ''
278
279     c.process_arriving_data(d)
280     c.new_request(request)
281
282 def start_http():
283   resource = IphttpResource()
284   sitefactory = twisted.web.server.Site(resource)
285   for addrspec in cfg.get('server','addrs').split():
286     try:
287       addr = ipaddress.IPv4Address(addrspec)
288       endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
289     except AddressValueError:
290       addr = ipaddress.IPv6Address(addrspec)
291       endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
292     ep = endpointfactory(reactor, cfg.getint('server','port'), addr)
293     ep.listen(sitefactory)
294
295 #---------- config and setup ----------
296         
297 def process_cfg():
298   global network
299   global host
300   global relay
301   global ipif_command
302
303   network = ipnetwork(cfg.get('virtual','network'))
304   if network.num_addresses < 3 + 2:
305     raise ValueError('network needs at least 2^3 addresses')
306
307   try:
308     host = cfg.get('virtual','host')
309   except NoOptionError:
310     host = next(network.hosts())
311
312   try:
313     relay = cfg.get('virtual','relay')
314   except NoOptionError:
315     for search in network.hosts():
316       if search == host: continue
317       relay = search
318       break
319
320   for cs in cfg.sections():
321     if not (':' in cs or '.' in cs): continue
322     ci = ipaddr(cs)
323     if ci not in network:
324       raise ValueError('client %s not in network' % ci)
325     if ci in clients:
326       raise ValueError('multiple client cfg sections for %s' % ci)
327     clients[ci] = Client(ci, cs)
328
329   global mtu
330   mtu = cfg.get('virtual','mtu')
331
332   iic_vars = { }
333   for k in ('host','relay','mtu','network'):
334     iic_vars[k] = globals()[k]
335
336   ipif_command = cfg.get('server','ipif', vars=iic_vars)
337
338 def crash_on_critical(event):
339   if event.get('log_level') >= LogLevel.critical:
340     print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
341     #print('crashing!', file=sys.stderr)
342     #os._exit(1)
343     try: reactor.stop()
344     except twisted.internet.error.ReactorNotRunning: pass
345
346 def startup():
347   global cfg
348
349   op = OptionParser()
350   op.add_option('-c', '--config', dest='configfile',
351                 default='/etc/hippottd/server.conf')
352   global opts
353   (opts, args) = op.parse_args()
354   if len(args): op.error('no non-option arguments please')
355
356   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
357
358   cfg = ConfigParser()
359   cfg.read_string(defcfg)
360   cfg.read(opts.configfile)
361   process_cfg()
362
363   start_ipif()
364   start_http()
365
366 startup()
367 reactor.run()