chiark / gitweb /
wip, config reorg
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import twisted
7 from twisted.internet import reactor
8 from twisted.logger import LogLevel
9
10 import ipaddress
11 from ipaddress import AddressValueError
12
13 import hippotat.slip as slip
14
15 from optparse import OptionParser
16 from configparser import ConfigParser
17 from configparser import NoOptionError
18
19 import collections
20
21 # these need to be defined here so that they can be imported by import *
22 cfg = ConfigParser()
23 optparser = OptionParser()
24
25 class ConfigResults:
26   def __init__(self, d = { }):
27     self.__dict__ = d
28   def __repr__(self):
29     return 'ConfigResults('+repr(self.__dict__)+')'
30
31 c = ConfigResults()
32
33 #---------- packet parsing ----------
34
35 def packet_addrs(packet):
36   version = packet[0] >> 4
37   if version == 4:
38     addrlen = 4
39     saddroff = 3*4
40     factory = ipaddress.IPv4Address
41   elif version == 6:
42     addrlen = 16
43     saddroff = 2*4
44     factory = ipaddress.IPv6Address
45   else:
46     raise ValueError('unsupported IP version %d' % version)
47   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
48   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
49   return (saddr, daddr)
50
51 #---------- address handling ----------
52
53 def ipaddr(input):
54   try:
55     r = ipaddress.IPv4Address(input)
56   except AddressValueError:
57     r = ipaddress.IPv6Address(input)
58   return r
59
60 def ipnetwork(input):
61   try:
62     r = ipaddress.IPv4Network(input)
63   except NetworkValueError:
64     r = ipaddress.IPv6Network(input)
65   return r
66
67 #---------- ipif (SLIP) subprocess ----------
68
69 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
70   def __init__(self, router):
71     self._buffer = b''
72     self._router = router
73   def connectionMade(self): pass
74   def outReceived(self, data):
75     #print('RECV ', repr(data))
76     self._buffer += data
77     packets = slip.decode(self._buffer)
78     self._buffer = packets.pop()
79     for packet in packets:
80       if not len(packet): continue
81       (saddr, daddr) = packet_addrs(packet)
82       self._router(packet, saddr, daddr)
83   def processEnded(self, status):
84     status.raiseException()
85
86 def start_ipif(command, router):
87   global ipif
88   ipif = _IpifProcessProtocol(router)
89   reactor.spawnProcess(ipif,
90                        '/bin/sh',['sh','-xc', command],
91                        childFDs={0:'w', 1:'r', 2:2})
92
93 def queue_inbound(packet):
94   ipif.transport.write(slip.delimiter)
95   ipif.transport.write(slip.encode(packet))
96   ipif.transport.write(slip.delimiter)
97
98 #---------- packet queue ----------
99
100 class PacketQueue():
101   def __init__(self, max_queue_time):
102     self._max_queue_time = max_queue_time
103     self._pq = collections.deque() # packets
104
105   def append(self, packet):
106     self._pq.append((time.monotonic(), packet))
107
108   def nonempty(self):
109     while True:
110       try: (queuetime, packet) = self._pq[0]
111       except IndexError: return False
112
113       age = time.monotonic() - queuetime
114       if age > self.max_queue_time:
115         # strip old packets off the front
116         self._pq.popleft()
117         continue
118
119       return True
120
121   def popleft(self):
122     # caller must have checked nonempty
123     try: (dummy, packet) = self._pq[0]
124     except IndexError: return None
125     return packet
126
127 #---------- error handling ----------
128
129 def crash(err):
130   print('CRASH ', err, file=sys.stderr)
131   try: reactor.stop()
132   except twisted.internet.error.ReactorNotRunning: pass
133
134 def crash_on_defer(defer):
135   defer.addErrback(lambda err: crash(err))
136
137 def crash_on_critical(event):
138   if event.get('log_level') >= LogLevel.critical:
139     crash(twisted.logger.formatEvent(event))
140
141 #---------- config processing ----------
142
143 def process_cfg_common_always():
144   global mtu
145   c.mtu = cfg.get('virtual','mtu')
146
147 def process_cfg_ipif(section, varmap):
148   for d, s in varmap:
149     try: v = getattr(c, s)
150     except KeyError: pass
151     setattr(c, d, v)
152
153   print(repr(c))
154
155   c.ipif_command = cfg.get(section,'ipif', vars=c.__dict__)
156
157 def process_cfg_network():
158   c.network = ipnetwork(cfg.get('virtual','network'))
159   if c.network.num_addresses < 3 + 2:
160     raise ValueError('network needs at least 2^3 addresses')
161
162 def process_cfg_server():
163   try:
164     c.server = cfg.get('virtual','server')
165   except NoOptionError:
166     process_cfg_network()
167     c.server = next(c.network.hosts())
168
169 class ServerAddr():
170   def __init__(self, port, addrspec):
171     self.port = port
172     # also self.addr
173     try:
174       self.addr = ipaddress.IPv4Address(addrspec)
175       self._endpointfactory = twisted.internet.endpoints.TCP4ServerEndpoint
176       self._inurl = '%s'
177     except AddressValueError:
178       self.addr = ipaddress.IPv6Address(addrspec)
179       self._endpointfactory = twisted.internet.endpoints.TCP6ServerEndpoint
180       self._inurl = '[%s]'
181   def make_endpoint(self):
182     return self._endpointfactory(reactor, self.port, self.addr)
183   def url(self):
184     url = 'http://' + (self._inurl % self.addr)
185     if self.port != 80: url += ':%d' % self.port
186     url += '/'
187     return url
188     
189 def process_cfg_saddrs():
190   port = cfg.getint('server','port')
191
192   c.saddrs = [ ]
193   for addrspec in cfg.get('server','addrs').split():
194     sa = ServerAddr(port, addrspec)
195     c.saddrs.append(sa)
196
197 def process_cfg_clients(constructor):
198   c.clients = [ ]
199   for cs in cfg.sections():
200     if not (':' in cs or '.' in cs): continue
201     ci = ipaddr(cs)
202     pw = cfg.get(cs, 'password')
203     constructor(ci,cs,pw)
204
205 #---------- startup ----------
206
207 def common_startup(defcfg):
208   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
209
210   optparser.add_option('-c', '--config', dest='configfile',
211                        default='/etc/hippotat/config')
212   (opts, args) = optparser.parse_args()
213   if len(args): optparser.error('no non-option arguments please')
214
215   cfg.read_string(defcfg)
216   cfg.read(opts.configfile)
217
218 def common_run():
219   reactor.run()
220   print('CRASHED (end)', file=sys.stderr)