chiark / gitweb /
f51caf555bb0bc844be37d89d5757e3cd4e34c47
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import twisted
7 from twisted.internet import reactor
8 from twisted.logger import LogLevel
9 import twisted.internet.endpoints
10
11 import ipaddress
12 from ipaddress import AddressValueError
13
14 import hippotat.slip as slip
15
16 from optparse import OptionParser
17 from configparser import ConfigParser
18 from configparser import NoOptionError
19
20 import collections
21
22 defcfg = '''
23 [DEFAULT]
24 #[<client>] overrides
25 max_batch_down = 65536           # used by server, subject to [limits]
26 max_queue_time = 10              # used by server, subject to [limits]
27 max_request_time = 54            # used by server, subject to [limits]
28 target_requests_outstanding = 3  # must match; subject to [limits] on server
29 max_requests_outstanding = 4     # used by client
30 max_batch_up = 4000              # used by client
31
32 #[server] or [<client>] overrides
33 ipif = userv root ipif %(local)s,%(peer)s,%(mtu)s,slip %(rnets)s
34 # extra interpolations:  %(local)s        %(peer)s          %(rnet)s
35 #  obtained   on server  [virtual]server  [virtual]relay    [virtual]network
36 #      from   on client  <client>         [virtual]server   [virtual]routes
37
38 [virtual]
39 mtu = 1500
40 routes = ''
41 # network = <prefix>/<len>  # mandatory for server
42 # server  = <ipaddr>   # used by both, default is computed from `network'
43 # relay   = <ipaddr>   # used by server, default from `network' and `server'
44 #  default server is first host in network
45 #  default relay is first host which is not server
46
47 [server]
48 # addrs = 127.0.0.1 ::1    # mandatory for server
49 port = 80                  # used by server
50 # url              # used by client; default from first `addrs' and `port'
51
52 # [<client-ip4-or-ipv6-address>]
53 # password = <password>    # used by both, must match
54
55 [limits]
56 max_batch_down = 262144           # used by server
57 max_queue_time = 121              # used by server
58 max_request_time = 121            # used by server
59 target_requests_outstanding = 10  # used by server
60 '''
61
62 # these need to be defined here so that they can be imported by import *
63 cfg = ConfigParser()
64 optparser = OptionParser()
65
66 class ConfigResults:
67   def __init__(self, d = { }):
68     self.__dict__ = d
69   def __repr__(self):
70     return 'ConfigResults('+repr(self.__dict__)+')'
71
72 c = ConfigResults()
73
74 #---------- packet parsing ----------
75
76 def packet_addrs(packet):
77   version = packet[0] >> 4
78   if version == 4:
79     addrlen = 4
80     saddroff = 3*4
81     factory = ipaddress.IPv4Address
82   elif version == 6:
83     addrlen = 16
84     saddroff = 2*4
85     factory = ipaddress.IPv6Address
86   else:
87     raise ValueError('unsupported IP version %d' % version)
88   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
89   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
90   return (saddr, daddr)
91
92 #---------- address handling ----------
93
94 def ipaddr(input):
95   try:
96     r = ipaddress.IPv4Address(input)
97   except AddressValueError:
98     r = ipaddress.IPv6Address(input)
99   return r
100
101 def ipnetwork(input):
102   try:
103     r = ipaddress.IPv4Network(input)
104   except NetworkValueError:
105     r = ipaddress.IPv6Network(input)
106   return r
107
108 #---------- ipif (SLIP) subprocess ----------
109
110 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
111   def __init__(self, router):
112     self._buffer = b''
113     self._router = router
114   def connectionMade(self): pass
115   def outReceived(self, data):
116     #print('IPIF-GOT ', repr(data))
117     self._buffer += data
118     packets = slip.decode(self._buffer)
119     self._buffer = packets.pop()
120     for packet in packets:
121       if not len(packet): continue
122       (saddr, daddr) = packet_addrs(packet)
123       if saddr.is_link_local or daddr.is_link_local:
124         log_discard(packet, saddr, daddr, 'link-local')
125         continue
126       self._router(packet, saddr, daddr)
127   def processEnded(self, status):
128     status.raiseException()
129
130 def start_ipif(command, router):
131   global ipif
132   ipif = _IpifProcessProtocol(router)
133   reactor.spawnProcess(ipif,
134                        '/bin/sh',['sh','-xc', command],
135                        childFDs={0:'w', 1:'r', 2:2})
136
137 def queue_inbound(packet):
138   ipif.transport.write(slip.delimiter)
139   ipif.transport.write(slip.encode(packet))
140   ipif.transport.write(slip.delimiter)
141
142 #---------- packet queue ----------
143
144 class PacketQueue():
145   def __init__(self, max_queue_time):
146     self._max_queue_time = max_queue_time
147     self._pq = collections.deque() # packets
148
149   def append(self, packet):
150     self._pq.append((time.monotonic(), packet))
151
152   def nonempty(self):
153     while True:
154       try: (queuetime, packet) = self._pq[0]
155       except IndexError: return False
156
157       age = time.monotonic() - queuetime
158       if age > self.max_queue_time:
159         # strip old packets off the front
160         self._pq.popleft()
161         continue
162
163       return True
164
165   def popleft(self):
166     # caller must have checked nonempty
167     try: (dummy, packet) = self._pq[0]
168     except IndexError: return None
169     return packet
170
171 #---------- error handling ----------
172
173 def crash(err):
174   print('CRASH ', err, file=sys.stderr)
175   try: reactor.stop()
176   except twisted.internet.error.ReactorNotRunning: pass
177
178 def crash_on_defer(defer):
179   defer.addErrback(lambda err: crash(err))
180
181 def crash_on_critical(event):
182   if event.get('log_level') >= LogLevel.critical:
183     crash(twisted.logger.formatEvent(event))
184
185 #---------- config processing ----------
186
187 def process_cfg_common_always():
188   global mtu
189   c.mtu = cfg.get('virtual','mtu')
190
191 def process_cfg_ipif(section, varmap):
192   for d, s in varmap:
193     try: v = getattr(c, s)
194     except AttributeError: continue
195     setattr(c, d, v)
196
197   print(repr(c))
198
199   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
200
201 def process_cfg_network():
202   c.network = ipnetwork(cfg.get('virtual','network'))
203   if c.network.num_addresses < 3 + 2:
204     raise ValueError('network needs at least 2^3 addresses')
205
206 def process_cfg_server():
207   try:
208     c.server = cfg.get('virtual','server')
209   except NoOptionError:
210     process_cfg_network()
211     c.server = next(c.network.hosts())
212
213 class ServerAddr():
214   def __init__(self, port, addrspec):
215     self.port = port
216     # also self.addr
217     try:
218       self.addr = ipaddress.IPv4Address(addrspec)
219       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
220       self._inurl = '%s'
221     except AddressValueError:
222       self.addr = ipaddress.IPv6Address(addrspec)
223       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
224       self._inurl = '[%s]'
225   def make_endpoint(self):
226     return self._endpointfactory(reactor, self.port, self.addr)
227   def url(self):
228     url = 'http://' + (self._inurl % self.addr)
229     if self.port != 80: url += ':%d' % self.port
230     url += '/'
231     return url
232     
233 def process_cfg_saddrs():
234   try: port = cfg.getint('server','port')
235   except NoOptionError: port = 80
236
237   c.saddrs = [ ]
238   for addrspec in cfg.get('server','addrs').split():
239     sa = ServerAddr(port, addrspec)
240     c.saddrs.append(sa)
241
242 def process_cfg_clients(constructor):
243   c.clients = [ ]
244   for cs in cfg.sections():
245     if not (':' in cs or '.' in cs): continue
246     ci = ipaddr(cs)
247     pw = cfg.get(cs, 'password')
248     constructor(ci,cs,pw)
249
250 #---------- startup ----------
251
252 def common_startup(defcfg):
253   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
254
255   optparser.add_option('-c', '--config', dest='configfile',
256                        default='/etc/hippotat/config')
257   (opts, args) = optparser.parse_args()
258   if len(args): optparser.error('no non-option arguments please')
259
260   cfg.read_string(defcfg)
261   cfg.read(opts.configfile)
262
263 def common_run():
264   reactor.run()
265   print('CRASHED (end)', file=sys.stderr)