chiark / gitweb /
wip, and move PacketQueue
[hippotat.git] / hippotat / __init__.py
1 # -*- python -*-
2
import collections
import ipaddress
import time
from ipaddress import AddressValueError

import twisted
import twisted.internet.protocol
from twisted.internet import reactor

import hippotat.slip as slip
10
11 #---------- packet parsing ----------
12
def packet_addrs(packet):
  """Return the (source, destination) addresses of a raw IP packet.

  packet must be bytes; the IP version is taken from the top nibble of
  the first byte.  The result is a pair of ipaddress.IPv4Address or
  ipaddress.IPv6Address objects, matching the packet's version.

  Raises ValueError for IP versions other than 4 and 6.  No length
  check is done here: an empty packet raises IndexError, and one too
  short to hold both addresses makes the ipaddress constructor raise
  AddressValueError.
  """
  ipver = packet[0] >> 4
  if ipver == 4:
    # IPv4 header: source address at byte 12, destination right after.
    (alen, soff, addrclass) = (4, 12, ipaddress.IPv4Address)
  elif ipver == 6:
    # IPv6 header: source address at byte 8, destination right after.
    (alen, soff, addrclass) = (16, 8, ipaddress.IPv6Address)
  else:
    raise ValueError('unsupported IP version %d' % ipver)
  doff = soff + alen
  return (addrclass(packet[soff:doff]), addrclass(packet[doff:doff + alen]))
28
29 #---------- address handling ----------
30
def ipaddr(input):
  """Parse input as an IPv4 address, falling back to IPv6.

  Returns an ipaddress.IPv4Address if input is acceptable to that
  constructor, otherwise an ipaddress.IPv6Address.  If input is valid
  as neither, the AddressValueError from the IPv6 attempt propagates.
  """
  try:
    return ipaddress.IPv4Address(input)
  except AddressValueError:
    # Not IPv4; let the IPv6 parser decide (and report any error).
    return ipaddress.IPv6Address(input)
37
def ipnetwork(input):
  """Parse input as an IPv4 network, falling back to IPv6.

  input is anything ipaddress.IPv4Network / ipaddress.IPv6Network
  accept, e.g. '192.0.2.0/24' or '2001:db8::/32'.  Returns an
  IPv4Network or IPv6Network.  Parsing is strict, so an address with
  host bits set raises ValueError.

  Raises AddressValueError if input is valid as neither family (the
  IPv6 parser's error is the one propagated), and NetmaskValueError if
  the address is IPv4 but the prefix/netmask is invalid.
  """
  try:
    r = ipaddress.IPv4Network(input)
  except AddressValueError:
    # Not an IPv4 address, so try IPv6.  Only AddressValueError is
    # caught, as in ipaddr(): ipaddress has no "NetworkValueError", and
    # a bad netmask on a genuine IPv4 address should be reported as
    # such rather than as an IPv6 parse failure.
    r = ipaddress.IPv6Network(input)
  return r
44
45 #---------- ipif (SLIP) subprocess ----------
46
class _IpifProcessProtocol(twisted.internet.protocol.ProcessProtocol):
  """Twisted process protocol for the ipif (SLIP) subprocess.

  Data read from the child's stdout is accumulated, split into frames
  with slip.decode, and each non-empty frame is passed to the router
  callback as router(packet, saddr, daddr).
  """

  def __init__(self, router):
    # Callback invoked for every decoded packet.
    self._router = router
    # Leftover bytes carried over between outReceived calls.
    self._buffer = b''

  def connectionMade(self):
    pass

  def outReceived(self, data):
    self._buffer += data
    frames = slip.decode(self._buffer)
    # The last element returned by slip.decode is kept back and
    # prefixed to the next chunk (presumably the incomplete trailing
    # frame -- depends on slip.decode's contract).
    self._buffer = frames.pop()
    for frame in frames:
      if not frame:
        continue
      (saddr, daddr) = packet_addrs(frame)
      self._router(frame, saddr, daddr)

  def processEnded(self, status):
    # Re-raise whatever the child's termination was reported as; any
    # exit of the ipif process is treated as an error.
    status.raiseException()
63
def start_ipif(command, router):
  """Launch the ipif subprocess and route the packets it emits.

  command is run as `/bin/sh -xc command`, so it is interpreted by the
  shell with command tracing enabled.  router is handed to
  _IpifProcessProtocol, which calls router(packet, saddr, daddr) for
  each packet read from the child.  The protocol instance is stored in
  the module-level global `ipif`, which queue_inbound() writes to.
  """
  global ipif
  ipif = _IpifProcessProtocol(router)
  shell_argv = ['sh', '-xc', command]
  # fd 0: we write to the child's stdin; fd 1: we read its stdout;
  # fd 2: the child shares our stderr.
  fd_map = {0: 'w', 1: 'r', 2: 2}
  reactor.spawnProcess(ipif, '/bin/sh', shell_argv, childFDs=fd_map)
70
def queue_inbound(packet):
  """Write packet to the ipif subprocess's stdin as a SLIP frame.

  The encoded packet is written between two SLIP delimiters (the
  leading one presumably makes the receiver discard any partial frame,
  per the usual SLIP convention -- confirm).  Uses the global `ipif`,
  so start_ipif() must have been called first.
  """
  write = ipif.transport.write
  write(slip.delimiter)
  write(slip.encode(packet))
  write(slip.delimiter)
75
76 #---------- packet queue ----------
77
class PacketQueue():
  """FIFO of packets which discards packets that have waited too long.

  Each entry is stamped with time.monotonic() when appended.
  nonempty() drops expired entries from the front before reporting
  whether anything remains; popleft() then removes the oldest packet.
  """

  def __init__(self, max_queue_time):
    # Maximum age in seconds (time.monotonic() units) before a queued
    # packet is discarded by nonempty().
    self._max_queue_time = max_queue_time
    self._pq = collections.deque() # (queuetime, packet) pairs, oldest first

  def append(self, packet):
    """Add packet to the back of the queue, stamped with the current time."""
    self._pq.append((time.monotonic(), packet))

  def nonempty(self):
    """Discard expired packets from the front; return True if any remain.

    Only the front entry is examined.  Entries are appended in
    monotonic-time order, so once the front packet is young enough,
    every later one is too.
    """
    while self._pq:
      (queuetime, packet) = self._pq[0]
      age = time.monotonic() - queuetime
      if age > self._max_queue_time:
        # strip old packets off the front
        self._pq.popleft()
        continue
      return True
    return False

  def popleft(self):
    """Remove and return the oldest packet, or None if the queue is empty.

    No expiry check is made here: the caller must have called
    nonempty() first so that stale packets have been discarded.
    """
    try:
      (queuetime, packet) = self._pq.popleft()
    except IndexError:
      return None
    return packet