chiark / gitweb /
introduce c. etc.
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
3 import signal
4 signal.signal(signal.SIGINT, signal.SIG_DFL)
5
6 import twisted
7 from twisted.internet import reactor
8 from twisted.logger import LogLevel
9
10 import ipaddress
11 from ipaddress import AddressValueError
12
13 import hippotat.slip as slip
14
15 from optparse import OptionParser
16 from configparser import ConfigParser
17 from configparser import NoOptionError
18
19 import collections
20
21 # these need to be defined here so that they can be imported by import *
22 cfg = ConfigParser()
23 optparser = OptionParser()
24
25 class ConfigResults:
26   def __init__(self, d = { }):
27     self.__dict__ = d
28   def __repr__(self):
29     return 'ConfigResults('+repr(self.__dict__)+')'
30
31 c = ConfigResults()
32
33 #---------- packet parsing ----------
34
35 def packet_addrs(packet):
36   version = packet[0] >> 4
37   if version == 4:
38     addrlen = 4
39     saddroff = 3*4
40     factory = ipaddress.IPv4Address
41   elif version == 6:
42     addrlen = 16
43     saddroff = 2*4
44     factory = ipaddress.IPv6Address
45   else:
46     raise ValueError('unsupported IP version %d' % version)
47   saddr = factory(packet[ saddroff           : saddroff + addrlen   ])
48   daddr = factory(packet[ saddroff + addrlen : saddroff + addrlen*2 ])
49   return (saddr, daddr)
50
51 #---------- address handling ----------
52
53 def ipaddr(input):
54   try:
55     r = ipaddress.IPv4Address(input)
56   except AddressValueError:
57     r = ipaddress.IPv6Address(input)
58   return r
59
60 def ipnetwork(input):
61   try:
62     r = ipaddress.IPv4Network(input)
63   except NetworkValueError:
64     r = ipaddress.IPv6Network(input)
65   return r
66
67 #---------- ipif (SLIP) subprocess ----------
68
69 class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
70   def __init__(self, router):
71     self._buffer = b''
72     self._router = router
73   def connectionMade(self): pass
74   def outReceived(self, data):
75     #print('RECV ', repr(data))
76     self._buffer += data
77     packets = slip.decode(self._buffer)
78     self._buffer = packets.pop()
79     for packet in packets:
80       if not len(packet): continue
81       (saddr, daddr) = packet_addrs(packet)
82       self._router(packet, saddr, daddr)
83   def processEnded(self, status):
84     status.raiseException()
85
86 def start_ipif(command, router):
87   global ipif
88   ipif = _IpifProcessProtocol(router)
89   reactor.spawnProcess(ipif,
90                        '/bin/sh',['sh','-xc', command],
91                        childFDs={0:'w', 1:'r', 2:2})
92
93 def queue_inbound(packet):
94   ipif.transport.write(slip.delimiter)
95   ipif.transport.write(slip.encode(packet))
96   ipif.transport.write(slip.delimiter)
97
98 #---------- packet queue ----------
99
100 class PacketQueue():
101   def __init__(self, max_queue_time):
102     self._max_queue_time = max_queue_time
103     self._pq = collections.deque() # packets
104
105   def append(self, packet):
106     self._pq.append((time.monotonic(), packet))
107
108   def nonempty(self):
109     while True:
110       try: (queuetime, packet) = self._pq[0]
111       except IndexError: return False
112
113       age = time.monotonic() - queuetime
114       if age > self.max_queue_time:
115         # strip old packets off the front
116         self._pq.popleft()
117         continue
118
119       return True
120
121   def popleft(self):
122     # caller must have checked nonempty
123     try: (dummy, packet) = self._pq[0]
124     except IndexError: return None
125     return packet
126
127 #---------- error handling ----------
128
129 def crash(err):
130   print('CRASH ', err, file=sys.stderr)
131   try: reactor.stop()
132   except twisted.internet.error.ReactorNotRunning: pass
133
134 def crash_on_defer(defer):
135   defer.addErrback(lambda err: crash(err))
136
137 def crash_on_critical(event):
138   if event.get('log_level') >= LogLevel.critical:
139     crash(twisted.logger.formatEvent(event))
140
141 #---------- config processing ----------
142
143 def process_cfg_common_always():
144   global mtu
145   c.mtu = cfg.get('virtual','mtu')
146
147 #---------- startup ----------
148
149 def common_startup(defcfg):
150   twisted.logger.globalLogPublisher.addObserver(crash_on_critical)
151
152   optparser.add_option('-c', '--config', dest='configfile',
153                        default='/etc/hippotat/config')
154   (opts, args) = optparser.parse_args()
155   if len(args): optparser.error('no non-option arguments please')
156
157   cfg.read_string(defcfg)
158   cfg.read(opts.configfile)
159
160 def common_run():
161   reactor.run()
162   print('CRASHED (end)', file=sys.stderr)