chiark / gitweb /
ipaddrset: Define __bool__ and make __nonzero__ an alias
[secnet.git] / ipaddrset.py
1 """IP address set manipulation, built on top of ipaddr.py"""
2
3 # This file is Free Software.  It was originally written for secnet.
4 #
5 # Copyright 2014 Ian Jackson
6 #
7 # You may redistribute secnet as a whole and/or modify it under the
8 # terms of the GNU General Public License as published by the Free
9 # Software Foundation; either version 3, or (at your option) any
10 # later version.
11 #
12 # You may redistribute this file and/or modify it under the terms of
13 # the GNU General Public License as published by the Free Software
14 # Foundation; either version 2, or (at your option) any later version.
15 # Note however that this version of ipaddrset.py uses the Python
16 # ipaddr library from Google, which is licenced only under the Apache
17 # Licence, version 2.0, which is only compatible with the GNU GPL v3
18 # (or perhaps later versions), and not with the GNU GPL v2.
19 #
20 # This software is distributed in the hope that it will be useful,
21 # but WITHOUT ANY WARRANTY; without even the implied warranty of
22 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
23 # GNU General Public License for more details.
24 #
25 # You should have received a copy of the GNU General Public License
26 # along with this software; if not, see
27 # https://www.gnu.org/licenses/gpl.html.
28
29 import ipaddr
30
# IP versions handled by IPAddressSet.  The order here is the order in
# which networks are yielded when iterating a set (IPv6 before IPv4).
_vsns = [6, 4]


class IPAddressSet:
    """A set of IP addresses

    The set is stored as one list of ipaddr network objects per IP
    version (the keys of self._v are the entries of _vsns).  The public
    mutators pass those lists through ipaddr.collapse_address_list, and
    __eq__ compares the lists directly.

    Comparison operators follow set containment: a >= b is true iff a
    completely contains b, and a > b iff additionally a != b.
    """

    # constructors
    def __init__(self, l=()):
        "New set contains each ipaddr.IPNetwork in the sequence l"
        # The default is an immutable empty tuple rather than a shared
        # mutable list.
        self._v = {}
        for v in _vsns:
            self._v[v] = []
        self.append(l)

    # housekeeping and representation
    def _compact(self):
        "Collapses each per-version list with ipaddr.collapse_address_list"
        for v in _vsns:
            self._v[v] = ipaddr.collapse_address_list(self._v[v])

    def __repr__(self):
        return "IPAddressSet(%s)" % self.networks()

    def str(self, comma=",", none="-"):
        """Human-readable string with controllable delimiters

        comma separates the networks; none is returned for an empty set.
        """
        if not self:
            return none
        # `str` here is the builtin: method names are not in scope
        # inside the method body.
        return comma.join(map(str, self.networks()))

    def __str__(self):
        return self.str()

    # mutators
    def append(self, l):
        "Appends each ipaddr.IPNetwork in the sequence l to self"
        self._append(l)
        self._compact()

    def _append(self, l):
        """Appends each ipaddr.IPNetwork in the sequence l to self,
           without compacting afterwards"""
        for a in l:
            self._v[a.version].append(a)

    # enquirers including standard comparisons
    def __bool__(self):
        "True iff the set contains at least one network"
        return any(self._v[v] for v in _vsns)
    __nonzero__ = __bool__  # for python2

    def __eq__(self, other):
        "True iff both sets hold the same networks for every IP version"
        return all(self._v[v] == other._v[v] for v in _vsns)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __ge__(self, other):
        """True iff self completely contains IPAddressSet other"""
        return all(self._contains_net(o) for o in other)

    def __le__(self, other):
        """True iff IPAddressSet other completely contains self"""
        return other.__ge__(self)

    def __gt__(self, other):
        """True iff self completely contains other and they differ"""
        # Must agree with __ge__: self is the (strictly) larger set.
        return self != other and self.__ge__(other)

    def __lt__(self, other):
        """True iff other completely contains self and they differ"""
        return other.__gt__(self)

    def __cmp__(self, other):
        # Only consulted by Python 2; Python 3 uses the rich comparison
        # methods above.  Sets where neither contains the other have no
        # ordering, hence NotImplemented.
        if self == other:
            return 0
        if self >= other:
            return +1
        if self <= other:
            return -1
        return NotImplemented

    def __iter__(self):
        "Iterates over minimal list of distinct IPNetworks in this set"
        for v in _vsns:
            for net in self._v[v]:
                yield net

    def networks(self):
        "Returns minimal list of distinct IPNetworks in this set"
        return list(self)

    # set operations
    def intersection(self, other):
        "Returns the intersection; does not modify self"
        r = IPAddressSet()
        for v in _vsns:
            for i in self._v[v]:
                for j in other._v[v]:
                    if not i.overlaps(j):
                        continue
                    # Of two overlapping networks, keep the one with
                    # the longer prefix (the narrower one).
                    if i.prefixlen > j.prefixlen:
                        r._append([i])
                    else:
                        r._append([j])
        return r

    def union(self, other):
        "Returns the union; does not modify self"
        r = IPAddressSet()
        r._append(self.networks())
        r._append(other.networks())
        r._compact()
        return r

    def _contains_net(self, n):
        """True iff self completely contains IPNetwork n"""
        # n is reported as contained when a single member network
        # overlaps it and is at least as wide (prefix no longer than n's).
        return any(i.overlaps(n) and n.prefixlen >= i.prefixlen
                   for i in self)

    def contains(self, thing):
        """Returns True iff self completely contains thing.
           thing may be an IPNetwork or an IPAddressSet"""
        # Anything with a .version attribute is treated as a single
        # network; an IPAddressSet has none, so it falls through to the
        # set-containment test.
        if hasattr(thing, 'version'):
            return self._contains_net(ipaddr.IPNetwork(thing))
        return self.__ge__(thing)
149
def complete_set():
    """Returns an IPAddressSet holding the /0 network of every IP
       version listed in _vsns, i.e. all addresses"""
    everything = IPAddressSet()
    for version in _vsns:
        # Address 0 of this IP version, formatted into its "/0" network.
        zero = ipaddr.IPAddress(0, version)
        whole = ipaddr.IPNetwork("%s/0" % zero)
        everything.append([whole])
    return everything