chiark / gitweb /
9987a1de084100b34d8c7f09405e083b9078c214
[hippotat.git] / server
1 #!/usr/bin/python3
2
import collections
import hmac
import ipaddress
import os
import sys
import syslog
import time
from configparser import ConfigParser
from configparser import NoOptionError
from ipaddress import AddressValueError
from optparse import OptionParser

import twisted
import twisted.internet
import twisted.internet.endpoints
from twisted.internet import reactor
from twisted.logger import LogLevel
from twisted.web.server import NOT_DONE_YET

#import twisted.web.server import Site
#from twisted.web.resource import Resource
26
# Registry of configured clients, keyed by client address (the value
# returned by ipaddr()) and holding the matching Client object.
# Filled in by process_cfg(); consulted by route() and by the HTTP
# resource when a request arrives.
clients = { }
28
def ipaddr(input):
  """Parse *input* as an IP address, preferring IPv4.

  Returns an ipaddress.IPv4Address when *input* is a valid IPv4
  address and an ipaddress.IPv6Address otherwise.  If it is neither,
  the AddressValueError from the IPv6 attempt propagates.
  """
  try:
    return ipaddress.IPv4Address(input)
  except AddressValueError:
    # not IPv4: the IPv6 parser gets the final say
    return ipaddress.IPv6Address(input)
35
def ipnetwork(input):
  """Parse *input* as an IP network, preferring IPv4.

  Returns an ipaddress.IPv4Network when the address part of *input*
  is valid IPv4, and an ipaddress.IPv6Network otherwise.  Errors from
  the IPv6 attempt (or an IPv4 netmask error) propagate as ValueError
  subclasses.
  """
  try:
    r = ipaddress.IPv4Network(input)
  except ipaddress.AddressValueError:
    # The ipaddress module has no NetworkValueError; an address part
    # that is not IPv4 is reported as AddressValueError, so that is
    # the signal to retry as IPv6.
    r = ipaddress.IPv6Network(input)
  return r
42
# Built-in default configuration, in ConfigParser syntax.  startup()
# loads this with cfg.read_string() before reading the config file
# named on the command line, so the file can override these values.
#
#  [DEFAULT]  values inherited by every section, including the
#             per-client sections read by Client.__init__
#  [virtual]  the virtual network; the commented-out names are
#             settings with no built-in default (process_cfg()
#             derives host and relay from network when they are absent)
#  [server]   ipif command template (interpolated by process_cfg()
#             with host, relay, mtu and network) and listen addresses
#  [limits]   upper bounds applied to the per-client settings
defcfg = '''
[DEFAULT]
max_batch_down = 65536
max_queue_time = 10
max_request_time = 54

[virtual]
mtu = 1500
# network
# [host]
# [relay]

[server]
ipif = userv root ipif %(host)s,%(relay)s,%(mtu)s,slip %(network)s
addrs = 127.0.0.1 ::1
port = 80

[limits]
max_batch_down = 262144
max_queue_time = 121
max_request_time = 121
'''
65
66 #---------- "router" ----------
67
def route(packet, daddr, saddr=None):
  """Deliver *packet* according to its destination address *daddr*.

  - a destination that is a configured client gets the packet queued
    on that client (Client.queue_outbound);
  - the server's own address, or anything outside the virtual
    network, is handed to the ipif subprocess (queue_inbound);
  - anything else (the relay address, or an unassigned address inside
    the network) is discarded and logged.

  *saddr* is used only for the discard log message.  It is optional
  so that existing two-argument callers keep working; when omitted it
  is taken from the packet itself, and only if a discard is logged.
  """
  dclient = clients.get(daddr)
  if dclient is not None:
    dclient.queue_outbound(packet)
    return

  if daddr == host or daddr not in network:
    queue_inbound(packet)
    return

  if saddr is None:
    # packet_addrs() yields (source, destination), as at its other
    # call sites in this file
    (saddr, _) = packet_addrs(packet)

  why = 'relay' if daddr == relay else 'no client'
  log_discard(packet, saddr, daddr, why)
79
def log_discard(packet, saddr, daddr, why):
  """Record at syslog debug level that a packet was dropped.

  The message names the source and destination addresses and the
  reason *why*; *packet* itself is not logged.
  """
  message = 'discarded packet %s -> %s (%s)' % (saddr, daddr, why)
  syslog.syslog(syslog.LOG_DEBUG, message)
83
84 #---------- ipif (slip subprocess) ----------
85
class IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
  """Protocol for the ipif subprocess: reads SLIP data from its stdout.

  Output is accumulated in self._buffer, split into packets, and each
  complete packet is passed to route().
  """

  def __init__(self):
    # bytes received from the subprocess but not yet routed
    self._buffer = b''

  def connectionMade(self): pass

  def outReceived(self, data):
    """Append *data* to the buffer and route every packet found in it."""
    # The buffer must live on the instance: data arrives in arbitrary
    # chunks, so unconsumed bytes have to survive until the next call.
    self._buffer += data
    packets = slip_decode(self._buffer)
    # The last element is kept back as the new buffer -- presumably the
    # trailing, possibly incomplete, packet; confirm against
    # slip_decode.
    self._buffer = packets.pop()
    for packet in packets:
      (saddr, daddr) = packet_addrs(packet)
      route(packet, daddr)

  def processEnded(self, status):
    """Re-raise the exit status; the subprocess is not expected to end."""
    status.raiseException()
99
def start_ipif():
  """Spawn the ipif subprocess and publish its protocol as global ipif.

  ipif_command is run through /bin/sh -c.  The child's stdin is a
  pipe we write to, its stdout a pipe we read from, and its stderr
  is our own fd 2.
  """
  global ipif
  ipif = IpifProcessProtocol()
  argv = ['sh','-c', ipif_command]
  fdmap = {0:'w', 1:'r', 2:2}
  reactor.spawnProcess(ipif, '/bin/sh', argv, childFDs=fdmap)
106
def queue_inbound(packet):
  """Send *packet* to the ipif subprocess.

  The SLIP-encoded packet is written to the subprocess transport with
  a delimiter both before and after it.
  """
  out = ipif.transport
  out.write(slip_delimiter)
  out.write(slip_encode(packet))
  out.write(slip_delimiter)
111
112 #---------- client ----------
113
class Client():
  """One configured client: its settings, waiting requests and packets.

  Packets destined for the client wait in self._pq; HTTP requests from
  the client wait in self._rq.  Whenever both a request and a
  non-expired packet are available, packets are written into the
  request as a delimiter-separated batch and the request is finished.

  A request stays in self._rq only while it may still be answered: it
  is removed when it is answered, when it times out, or when its
  connection is lost.
  """

  def __init__(self, ip, cs):
    """Set up the client with address *ip* from config section *cs*."""
    # instance data members
    self._ip = ip
    self._cs = cs
    self.pw = cfg.get(cs, 'password')
    self._rq = collections.deque() # requests
    self._pq = collections.deque() # packets, as (queue time, packet)
    # plus from config:
    #  .max_batch_down
    #  .max_queue_time
    #  .max_request_time
    for k in ('max_batch_down','max_queue_time','max_request_time'):
      req = cfg.getint(cs, k)
      limit = cfg.getint('limits',k)
      # the client section's value is capped by the [limits] section
      setattr(self, k, min(req, limit))

  def process_arriving_data(self, d):
    """Route every packet in the uploaded data *d*.

    Raises ValueError if a packet's source address is not this
    client's own address.
    """
    for packet in slip_decode(d):
      (saddr, daddr) = packet_addrs(packet)
      if saddr != self._ip:
        raise ValueError('wrong source address %s' % saddr)
      route(packet, daddr)

  def _req_forget(self, request):
    """Drop *request* from the wait queue; return whether it was there."""
    try: self._rq.remove(request)
    except ValueError: return False
    return True

  def _req_cancel(self, request):
    """Timeout handler: finish *request* if it is still waiting.

    The timer set in http_request is never cancelled, so this also
    runs for requests already answered or lost; those are no longer
    in the queue and must not be finished again.
    """
    if self._req_forget(request):
      request.finish()

  def _req_error(self, err, request):
    """Errback for request.notifyFinish(): the request cannot complete.

    The request is only forgotten; finish() is not called on it.
    """
    self._req_forget(request)

  def queue_outbound(self, packet):
    """Queue *packet* for this client and try to deliver it at once."""
    self._pq.append((time.monotonic(), packet))
    # a request may already be waiting for exactly this packet
    self._check_outbound()

  def http_request(self, request):
    """Accept *request* as a carrier for outbound packets.

    The request is answered as soon as packets are available, or
    finished empty after max_request_time seconds.
    """
    request.setHeader('Content-Type','application/octet-stream')
    reactor.callLater(self.max_request_time, self._req_cancel, request)
    request.notifyFinish().addErrback(self._req_error, request)
    self._rq.append(request)
    self._check_outbound()

  # the HTTP resource in this file calls this method by this name
  new_request = http_request

  def _check_outbound(self):
    """Pair waiting requests with queued packets until one runs out."""
    while True:
      try: request = self._rq[0]
      except IndexError: request = None
      if request is not None and request.finished:
        self._rq.popleft()
        continue

      # now request is an unfinished request, or None
      try: (queuetime, packet) = self._pq[0]
      except IndexError:
        # no packets, oh well
        break

      age = time.monotonic() - queuetime
      if age > self.max_queue_time:
        # expired while waiting: drop it even if no request is waiting
        self._pq.popleft()
        continue

      if request is None:
        # no request
        break

      # request, and also some non-expired packets
      while self._pq:
        (dummy, packet) = self._pq[0]
        encoded = slip_encode(packet)

        # The first packet of a batch is always sent; later ones only
        # while the batch stays within max_batch_down.
        if request.sentLength > 0:
          if (request.sentLength + len(slip_delimiter)
              + len(encoded) > self.max_batch_down):
            break
          request.write(slip_delimiter)

        request.write(encoded)
        self._pq.popleft()

      assert request.sentLength
      self._rq.popleft()
      request.finish()
      # round again, looking for more to do
197
class IphttpResource(twisted.web.resource.Resource):
  """HTTP endpoint through which clients exchange packets.

  POST arguments (Twisted supplies names and values as bytes, with a
  list of values per name; the first value is used):
    i    client address
    pw   client password
    mbd, mqt, mrt
         optional overrides for max_batch_down, max_queue_time and
         max_request_time
    d    optional uploaded packet data
  """

  # Serve every path from this resource; without this the site would
  # look for a child resource and answer 404.
  isLeaf = True

  def render_POST(self, request):
    """Authenticate the client, take its upload, and park the request.

    Raises KeyError for a missing argument or unknown client and
    ValueError for a malformed value or wrong password.  Returns
    NOT_DONE_YET because the client object finishes the request later.
    """
    args = request.args

    # find client, update config, etc.
    ci = ipaddr(args[b'i'][0].decode('ascii'))
    c = clients[ci]
    pw = args[b'pw'][0]
    # constant-time comparison: pw comes straight off the network
    if not hmac.compare_digest(pw, c.pw.encode('utf-8')):
      raise ValueError('bad password')

    # update config
    for r, w in ((b'mbd', 'max_batch_down'),
                 (b'mqt', 'max_queue_time'),
                 (b'mrt', 'max_request_time')):
      try: v = args[r][0]
      except KeyError: continue
      # same cap as Client.__init__ applies to configured values
      setattr(c, w, min(int(v), cfg.getint('limits', w)))

    try: d = args[b'd'][0]
    except KeyError: d = b''

    c.process_arriving_data(d)
    c.http_request(request)
    return NOT_DONE_YET
220
def start_http():
  """Start listening for HTTP on every address in [server] addrs.

  Each whitespace-separated entry is tried as IPv4 first and then as
  IPv6 (an entry that is neither raises AddressValueError); one
  listening endpoint per address is created on [server] port, all
  sharing a single site.
  """
  resource = IphttpResource()
  sitefactory = twisted.web.server.Site(resource)
  port = cfg.getint('server','port')
  for addrspec in cfg.get('server','addrs').split():
    try:
      addr = ipaddress.IPv4Address(addrspec)
      endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
    except AddressValueError:
      addr = ipaddress.IPv6Address(addrspec)
      endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
    # The bind address must be passed by keyword: the third positional
    # parameter of these endpoints is the listen backlog.
    ep = endpointfactory(reactor, port, interface=str(addr))
    ep.listen(sitefactory)
233
234 #---------- config and setup ----------
235         
def process_cfg():
  """Turn the loaded configuration into the module's global state.

  Sets the globals network, host, relay, mtu and ipif_command, and
  fills clients with one Client per config section whose name looks
  like an IP address (contains ':' or '.').

  host defaults to the first host address of the network, relay to
  the first host address that is not host.

  Raises ValueError if the network is too small, if a client address
  is outside the network, or if two sections name the same client.
  """
  global network
  global host
  global relay
  global ipif_command
  global mtu

  network = ipnetwork(cfg.get('virtual','network'))
  if network.num_addresses < 3 + 2:
    raise ValueError('network needs at least 2^3 addresses')

  # Configured addresses are parsed into address objects so that they
  # compare equal to the addresses produced by network.hosts() and
  # ipaddr(); a plain string never would.
  try:
    hostspec = cfg.get('virtual','host')
  except NoOptionError:
    host = next(network.hosts())
  else:
    host = ipaddr(hostspec)

  try:
    relayspec = cfg.get('virtual','relay')
  except NoOptionError:
    relay = next(h for h in network.hosts() if h != host)
  else:
    relay = ipaddr(relayspec)

  for cs in cfg.sections():
    if not (':' in cs or '.' in cs): continue
    ci = ipaddr(cs)
    if ci not in network:
      raise ValueError('client %s not in network' % ci)
    if ci in clients:
      raise ValueError('multiple client cfg sections for %s' % ci)
    clients[ci] = Client(ci, cs)

  mtu = cfg.get('virtual','mtu')

  # values substituted into the %(...)s references of [server] ipif
  iic_vars = {
    'host': host,
    'relay': relay,
    'mtu': mtu,
    'network': network,
  }

  ipif_command = cfg.get('server','ipif', vars=iic_vars)
276
def crash_on_critical(event):
  """Log observer: stop the reactor when a critical event is logged.

  The formatted event is printed to stderr first.  Events below
  critical level, or carrying no log_level at all, are ignored.
  """
  level = event.get('log_level')
  # an absent level cannot be ordered against a LogLevel constant
  if level is None or level < LogLevel.critical:
    return
  print('crashing: ', twisted.logger.formatEvent(event), file=sys.stderr)
  # the reactor may not have started yet, or may already be stopping
  try: reactor.stop()
  except twisted.internet.error.ReactorNotRunning: pass
284
def startup():
  """Parse the command line, load configuration and start the services.

  Sets the globals opts (parsed options) and cfg (the ConfigParser,
  holding the built-in defaults overlaid with the config file), then
  installs the crash-on-critical log observer's configuration
  consumers: process_cfg(), the ipif subprocess and the HTTP listeners.
  """
  global cfg
  global opts

  parser = OptionParser()
  parser.add_option('-c', '--config', dest='configfile',
                    default='/etc/hippottd/server.conf')
  (opts, leftover) = parser.parse_args()
  if len(leftover):
    parser.error('no non-option arguments please')

  # registered before anything else runs, so that critical events
  # logged during setup are observed too
  twisted.logger.globalLogPublisher.addObserver(crash_on_critical)

  # built-in defaults first, then the config file on top
  cfg = ConfigParser()
  cfg.read_string(defcfg)
  cfg.read(opts.configfile)
  process_cfg()

  start_ipif()
  start_http()
304
# Script entry: configure everything and start the subprocess and
# listeners, then hand control to the Twisted reactor.
startup()
reactor.run()