chiark / gitweb /
make-secnet-sites: set_property: Break out kw
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73
74 from argparseactionnoyes import ActionNoYes
75
# Version string recorded in the files this program generates.
VERSION="0.1.18"

from sys import version_info
if version_info.major == 2:  # for python2
    # Under Python 2, wrap stdin/stdout and replace open() so that they
    # read and write UTF-8 text.
    import codecs
    sys.stdin = codecs.getreader('utf-8')(sys.stdin)
    sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
    import io
    # NB: this replacement open() accepts only (filename, mode).
    open=lambda f,m='r': io.open(f,m,encoding='utf-8')

# Limits used by Tainted validation: maximum RSA and Diffie-Hellman sizes
# in bits, and maximum length of a name.  NB: this shadows the builtin
# max() for the rest of the module.
max={'rsa_bits':8200,'name':33,'dh_bits':8200}
87
class Tainted:
        """A word of untrusted input that must be validated before use.

        A value starts out unchecked (_ok is None).  One of the checking
        methods below (name, number, host, ...) validates it and marks it
        good (_ok True) or bad (_ok False, after issuing a complaint);
        each returns either the value or a harmless substitute.  str() of
        a Tainted raises, and output() exits if the value was never
        checked, so unvalidated data cannot leak into generated files.
        """
        def __init__(self,s,tline=None,tfile=None):
                self._s=s
                self._ok=None
                # By default record the input position currently being
                # parsed (the globals maintained by pfile).
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                # Forbid accidental interpolation of unchecked data.
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                "Mark this value bad and complain about it; returns False."
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                "Check the value is at most maxlen characters long."
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                """Check length, and that the regexp bad matches nowhere.

                maxlen defaults to the entry for what in the max table.
                """
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s): return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                "Record the verdict and return ifgood or ifbad accordingly."
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                "As _rtnval, returning the string itself if it is good."
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                # Peek at the value without checking or marking it.
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                """Return the value for copying to output.

                Bad values are replaced by ''; a value that was never
                checked is a program error, so report it and exit.
                """
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                ok=self._re_ok(Tainted.bad_name,what)
                # On failure substitute a unique name in the reserved
                # _-prefixed namespace.
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                # Each hex digit carries 4 bits; use integer division so
                # the limit is an int.
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                # An N-bit number has at most ceil(N*log10(2)) decimal
                # digits.  (Dividing by log10(2) would allow ~11x more.)
                maxlen=int(math.ceil(max[kind+'_bits'] * math.log10(2)))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Check for a decimal number in minn..maxx; return it as an int.

                Not for bignums.  On failure returns minn.
                """
                ok=self._re_ok(Tainted.bad_num,what,10)
                # Bind v on every path: if the syntax check fails it is
                # never parsed, but it is still passed to _rtnval below.
                v=minn
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                "Check for exactly byteslen bytes written in hex."
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
219
def parse_args():
        """Parse the command line into module-level globals.

        With --userv/-u exactly four positional arguments are required:
        header file, group-files directory, sites file and the group name
        (untrusted, supplied by the caller), and the caller must be listed
        in that group according to USERV_GROUP.  Otherwise at most two
        positional arguments, [infile [outfile]], are accepted.  On error
        a message is printed and the program exits with status 1.
        """
        global service
        global inputfile
        global header
        global groupfiledir
        global sitesfile
        global outputfile
        global group
        global user
        global of
        global prefix
        global key_prefix

        ap = argparse.ArgumentParser(description='process secnet sites files')
        ap.add_argument('--userv', '-u', action='store_true',
                        help='userv service fragment update mode')
        ap.add_argument('--conf-key-prefix', action=ActionNoYes,
                        default=True,
                        help='prefix conf file key names derived from sites data')
        ap.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        ap.add_argument('arg',nargs=argparse.REMAINDER)
        av = ap.parse_args()
        service = 1 if av.userv else 0
        prefix = '' if av.prefix is None else av.prefix[0]
        key_prefix = av.conf_key_prefix
        if service:
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (header, groupfiledir, sitesfile, group) = av.arg
                # untrusted argument from caller
                group = Tainted(group,0,'command line')
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                ugs=os.environ["USERV_GROUP"]
                if not any(group==g for g in ugs.split()):
                        # A Tainted value raises if interpolated directly,
                        # which would hide this message; echo the raw
                        # string, which the caller supplied themselves.
                        print("caller not in group %s"%group.raw())
                        sys.exit(1)
        else:
                # Usage is [infile [outfile]]; reject extra arguments
                # rather than silently ignoring them.
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (inputfile, outputfile) = (av.arg + [None]*2)[0:2]
274
# Parse the command line now: the top-level processing code further down
# reads the globals (service, header, inputfile, ...) that this sets.
parse_args()
276
277 # Classes describing possible datatypes in the configuration file
278
class basetype:
        "Common protocol for configuration types."
        def add(self,obj,w):
                """Handle a repeat of property w[0] on node obj.

                By default a property may be given only once, so complain.
                """
                propname=w[0].raw()
                complain("%s %s already has property %s defined"%
                        (obj.type,obj.name,propname))
284
class conflist:
        "A list of some kind of configuration type."
        def __init__(self,subtype,w):
                self.subtype=subtype
                # The first occurrence of the property starts the list...
                self.list=[]
                self.list.append(subtype(w))
        def add(self,obj,w):
                # ...and each repeated occurrence appends another item.
                item=self.subtype(w)
                self.list.append(item)
        def __str__(self):
                return ', '.join(str(item) for item in self.list)
def listof(subtype):
        "Return a constructor that builds a conflist of subtype from w."
        def make(w):
                return conflist(subtype, w)
        return make
296
class single_ipaddr (basetype):
        "An IP address"
        def __init__(self,w):
                # Marked OK before parsing: ip_address() raises ValueError
                # on bad syntax, which satisfies raw_mark_ok's contract.
                text=w[1].raw_mark_ok()
                self.addr=ipaddress.ip_address(text)
        def __str__(self):
                return '"%s"' % (self.addr,)
303
class networks (basetype):
        "A set of IP addresses specified as a list of networks"
        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for word in w[1:]:
                        # strict=True rejects networks with host bits set;
                        # a raise on bad syntax justifies raw_mark_ok.
                        net=ipaddress.ip_network(word.raw_mark_ok(),strict=True)
                        self.set.append([net])
        def __str__(self):
                quoted=['"%s"'%net for net in self.set.networks()]
                return ",".join(quoted)
313
class dhgroup (basetype):
        "A Diffie-Hellman group"
        def __init__(self,w):
                # Modulus and generator are both given in hex.
                self.mod=w[1].bignum_16('dh','dh mod')
                self.gen=w[2].bignum_16('dh','dh gen')
        def __str__(self):
                return 'diffie-hellman("{0}","{1}")'.format(self.mod,self.gen)
321
class hash (basetype):
        "A choice of hash function"
        def __init__(self,w):
                hname=w[1]
                self.ht=hname.raw()
                # Only md5 and sha1 are recognised.
                if self.ht in ('md5','sha1'):
                        hname.raw_mark_ok()
                else:
                        complain("unknown hash type %s"%(self.ht))
                        self.ht=None
        def __str__(self):
                return '%s' % (self.ht,)
334
class email (basetype):
        "An email address"
        def __init__(self,w):
                # Tainted.email() validates the syntax ('' if it is bad).
                self.addr=w[1].email()
        def __str__(self):
                return '<' + self.addr + '>'
341
class boolean (basetype):
        "A boolean"
        def __init__(self,w):
                v=w[1]
                text=v.raw()
                # Only the first character is examined: T/Y/1 mean true,
                # F/N/0 mean false.
                if re.match('[TtYy1]',text):
                        self.b=True
                elif re.match('[FfNn0]',text):
                        self.b=False
                else:
                        # NOTE(review): self.b is left unset and v is left
                        # unchecked in this case.
                        complain("invalid boolean value")
                        return
                v.raw_mark_ok()
        def __str__(self):
                return 'True' if self.b else 'False'
356
class num (basetype):
        "A decimal number"
        def __init__(self,w):
                # Accept 0 .. 0x7fffffff (2**31-1).
                upper=0x7fffffff
                self.n=w[1].number(0,upper)
        def __str__(self):
                return '%d' % (self.n,)
363
class address (basetype):
        "A DNS name and UDP port number"
        def __init__(self,w):
                self.adr=w[1].host()
                # UDP port numbers are 16-bit; 65535 is the largest valid
                # one (65536 was previously accepted by mistake).
                self.port=w[2].number(1,65535,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
371
class rsakey (basetype):
        "An RSA public key"
        def __init__(self,w):
                # Words: pubkey <bits> <e> <n> [<fifth word>]
                self.l=w[1].number(0,max['rsa_bits'],'rsa len')
                self.e=w[2].bignum_10('rsa','rsa e')
                self.n=w[3].bignum_10('rsa','rsa n')
                if len(w) > 4:
                        # The optional fifth word is checked as an email
                        # address; its value is not otherwise used.
                        w[4].email()
        def __str__(self):
                return 'rsa-public("{0}","{1}")'.format(self.e,self.n)
381
# Possible properties of configuration nodes.
# Maps each property keyword to (type, description): set_property builds
# the property value by calling type(w) with the line's words; the
# description is human-readable text.
keywords={
 'contact':(email,"Contact address"),
 'dh':(dhgroup,"Diffie-Hellman group"),
 'hash':(hash,"Hash function"),
 'key-lifetime':(num,"Maximum key lifetime (ms)"),
 'setup-timeout':(num,"Key setup timeout (ms)"),
 'setup-retries':(num,"Maximum key setup packet retries"),
 'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
 'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
 'restrict-nets':(networks,"Allowable networks"),
 'networks':(networks,"Claimed networks"),
 'pubkey':(rsakey,"RSA public site key"),
 'peer':(single_ipaddr,"Tunnel peer IP address"),
 'address':(address,"External contact address and port"),
 'mobile':(boolean,"Site is mobile"),
}
399
def sp(name,value):
        "Default property formatter: emit 'name value;' as one line."
        return "{0} {1};\n".format(name,value)
403
# All levels support these properties.
# Each value is a formatter, called by level.prop_out as
# f(name, str(value)) and returning the text to write into the generated
# configuration.  contact and restrict-nets are written only as comments.
global_properties={
        'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
        'dh':sp,
        'hash':sp,
        'key-lifetime':sp,
        'setup-timeout':sp,
        'setup-retries':sp,
        'wait-time':sp,
        'renegotiate-time':sp,
        'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
}
416
class level:
        "A level in the configuration hierarchy"
        # Overridden by the concrete level classes: depth in the
        # hierarchy (the root is 0), whether nodes are leaves (sites),
        # the formatter for each property allowed here (a None formatter
        # means the property is accepted but not output by output_props),
        # and the properties each node must have, directly or inherited.
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}
        def __init__(self,w):
                # w[0] is the level keyword, w[1] the node's name.
                self.type=w[0].keyword()
                self.name=w[1].name()
                self.properties={}
                self.children={}
        def indent(self,w,t):
                "Write t spaces of indentation to w."
                # Build exactly t spaces; slicing a fixed-width literal
                # would silently cap the indentation at its length.
                w.write(" "*t)
        def prop_out(self,n):
                "Format property n using this level's formatter for it."
                return self.allow_properties[n](n,str(self.properties[n]))
        def output_props(self,w,ind):
                "Write every property that has a formatter, sorted by name."
                for i in sorted(self.properties.keys()):
                        if self.allow_properties[i]:
                                self.indent(w,ind)
                                w.write("%s"%self.prop_out(i))
        def kname(self):
                "Config key: the name, optionally prefixed by the type's initial."
                return ((self.type[0].upper() if key_prefix else '')
                        + self.name)
        def output_data(self,w,path):
                "Write this node, its properties and its children to w."
                ind = 2*len(path)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.output_props(w,ind+2)
                if self.depth==1: w.write("\n")
                for k in sorted(self.children.keys()):
                        c=self.children[k]
                        c.output_data(w,path+(c,))
                self.indent(w,ind)
                w.write("};\n")
451
class vpnlevel(level):
        "VPN level in the configuration hierarchy"
        depth=1
        leaf=0
        type="vpn"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"VPN admin contact address"
        }
        def __init__(self,w):
                level.__init__(self,w)
        def output_vpnflat(self,w,path):
                "Output flattened list of site names for this VPN"
                ind=2*(len(path)+1)
                here=path+(self,)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                for child in self.children.values():
                        child.output_vpnflat(w,here)
                w.write("\n")
                self.indent(w,ind+2)
                names=[child.kname() for child in self.children.values()]
                w.write("all-sites %s;\n"%','.join(names))
                self.indent(w,ind)
                w.write("};\n")
477
class locationlevel(level):
        "Location level in the configuration hierarchy"
        depth=2
        leaf=0
        type="location"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"Location admin contact address",
        }
        def __init__(self,w):
                level.__init__(self,w)
                # Third word: the location's group name (pline compares
                # it with the userv caller's group).
                self.group=w[2].groupname()
        def output_vpnflat(self,w,path):
                "Output this location's entry: a path to each site's data."
                ind=2*(len(path)+1)
                self.indent(w,ind)
                refs=[]
                for site in self.children.values():
                        parts=[prefix+"vpn-data"]
                        parts.extend(node.kname() for node in path+(self,site))
                        refs.append('/'.join(parts))
                w.write("%s %s;\n"%(self.kname(),','.join(refs)))
503
class sitelevel(level):
        "Site level (i.e. a leafnode) in the configuration hierarchy"
        depth=3
        leaf=1
        type="site"
        allow_properties=global_properties.copy()
        allow_properties.update({
         'address':sp,
         'networks':None,
         'peer':None,
         'pubkey':(lambda n,v:"key %s;\n"%v),
         'mobile':sp,
        })
        require_properties={
         'dh':"Diffie-Hellman group",
         'contact':"Site admin contact address",
         'networks':"Networks claimed by the site",
         'hash':"hash function",
         'peer':"Gateway address of the site",
         'pubkey':"RSA public key of the site",
        }
        def __init__(self,w):
                level.__init__(self,w)
        def output_data(self,w,path):
                """Write this site's stanza, including its netlink.

                networks and peer have no formatter, so output_props skips
                them; they are written inside the netlink block instead.
                """
                ind=2*len(path)
                fullname='/'.join([node.name for node in path])
                def emit(depth,text):
                        self.indent(w,depth)
                        w.write(text)
                emit(ind,"%s {\n"%(self.kname()))
                emit(ind+2,"name \"%s\";\n"%(fullname,))
                self.output_props(w,ind+2)
                emit(ind+2,"link netlink {\n")
                emit(ind+4,"routes %s;\n"%str(self.properties["networks"]))
                emit(ind+4,"ptp-address %s;\n"%str(self.properties["peer"]))
                emit(ind+2,"};\n")
                emit(ind,"};\n")
545
# Levels in the configuration file, keyed by the keyword that introduces
# each one; the value is the level's class.
levels={'vpn':vpnlevel, 'location':locationlevel, 'site':sitelevel}
549
def complain(msg):
        "Complain about the input line currently being processed."
        where="%s line %d: "%(file,line)
        moan(where+msg)
def moan(msg):
        """Complain about something in general.

        Prints msg and counts it.  Once complaints has been set to None
        (after input checking is finished) any complaint is fatal.
        """
        global complaints
        print(msg)
        if complaints is None:
                sys.exit(1)
        complaints+=1
559
class UntaintedRoot():
        "Trusted stand-in for a Tainted word, used to build the root node."
        def __init__(self,s):
                self._s=s
        def name(self):
                return self._s
        # Both checks simply return the trusted string.
        keyword=name
564
# The root of the configuration tree; built from trusted words since it
# does not come from the input.
root=level([UntaintedRoot(x) for x in ['root','root']])
# All vpns are children of this node
# obstack: the path of nodes from root to the one currently being defined.
obstack=[root]
allow_defs=0   # New definitions are permitted only at depths >= this
569
def set_property(obj,w):
        """Set a property on a configuration node.

        w[0] is the property keyword and w[1:] its arguments.  The first
        occurrence constructs the value using the keyword's type; a repeat
        is handed to the existing value's add() method.
        """
        prop=w[0]
        kw=keywords[prop.raw_mark_ok()]
        key=prop.raw()
        if key not in obj.properties:
                obj.properties[key]=kw[0](w)
        else:
                obj.properties[key].add(obj,w)
578
579
def pline(il,allow_include=False):
        """Process one configuration file line.

        Returns the list of lines to copy into the regenerated sites
        file: normally the validated line itself, the processed lines of
        an included file, or [] when the line is rejected.
        """
        global allow_defs, obstack, root
        w=il.rstrip('\n').split()
        if len(w)==0: return ['']
        w=[Tainted(x) for x in w]
        keyword=w[0]
        current=obstack[len(obstack)-1]
        # Re-emit the line.  Tainted.output() refuses unchecked words, so
        # every word must have been validated before this is called.
        copyout=lambda: ['    '*len(obstack) +
                        ' '.join([ww.output() for ww in w]) +
                        '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                return pfilepath(newfile,allow_include=allow_include)
        if keyword.raw() in levels:
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[len(obstack)-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if current.depth==locationlevel.depth:
                                if service and group and group!=current.group:
                                        complain("Incorrect group!")
                                # A re-entered location normally repeats its
                                # group name.  Validate it in every mode,
                                # otherwise copyout() would reject it as
                                # unchecked (e.g. when processing a sites
                                # file outside service mode).
                                if len(w)>2: w[2].groupname()
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                obstack.append(current)
                return copyout()
        # Anything else must be a property allowed at the current level;
        # this also rejects unknown keywords.
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        if current.depth == vpnlevel.depth < allow_defs:
                complain("Not allowed to set VPN properties here")
                return []
        set_property(current,w)
        return copyout()
651
def pfilepath(pathname,allow_include=False):
        """Read and process the configuration file at pathname.

        Returns the output lines produced by pfile().
        """
        # Read everything up front so the file is closed even if
        # processing raises or exits (and before any nested include
        # opens further files).
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
657
def pfile(name,lines,allow_include=False):
        """Process the lines of a file, returning the output lines.

        name is used in complaints.  The position globals file and line
        are restored on return, so that after an "include" complaints
        about the including file still report its name and line numbers.
        """
        global file,line
        saved=(file,line)
        file=name
        line=0
        outlines=[]
        try:
                for i in lines:
                        line=line+1
                        if i.startswith('#'): continue
                        outlines += pline(i,allow_include=allow_include)
        finally:
                file,line=saved
        return outlines
669
def outputsites(w):
        """Write the generated secnet sites configuration to stream w.

        Writes a header, then three sections: the raw per-VPN data tree
        (prefix+"vpn-data"), the per-VPN flattened lists (prefix+"vpn"),
        and a combined prefix+"all-sites" list.
        """
        vpns=list(root.children.values())
        banner=("# secnet sites file autogenerated by make-secnet-sites "
                "version %s\n") % VERSION
        w.write(banner)
        w.write("# %s\n"%time.asctime(time.localtime(time.time())))
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in vpns:
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in vpns:
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        refs=["%svpn/%s/all-sites"%(prefix,vpn.kname()) for vpn in vpns]
        w.write(prefix+"all-sites %s;\n"%",".join(refs))
693
# Input position currently being parsed (maintained by pfile); used by
# complain() and as the default position of new Tainted values.
line=0
file=None
# Number of complaints so far; set to None once input checking is done,
# after which any complaint is fatal (see moan).
complaints=0
697
698 # Sanity check section
699 # Delete nodes where leaf=0 that have no children
700
def live(n):
        "Return 1 if node n is a leaf or has any leaf below it, else 0."
        if n.leaf:
                return 1
        return 1 if any(live(child) for child in n.children.values()) else 0
def delempty(n):
        "Recursively remove every child subtree of n that contains no leaf."
        for key in list(n.children):
                child=n.children[key]
                delempty(child)
                if not live(child):
                        del n.children[key]
713
714 # Check that all constraints are met (as far as I can tell
715 # restrict-nets/networks/peer are the only special cases)
716
def checkconstraints(n,p,ra):
        """Check node n and its descendants against their constraints.

        p holds the properties inherited from n's ancestors and ra the
        address set still permitted by the ancestors' restrict-nets.
        Problems are reported with moan().
        """
        inherited=dict(p)
        inherited.update(n.properties)
        for prop in n.require_properties.keys():
                if prop not in inherited:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,prop))
        for prop in inherited.keys():
                if prop not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,prop))
        # Check address range restrictions
        props=n.properties
        allowed=ra
        if "restrict-nets" in props:
                allowed=ra.intersection(props["restrict-nets"].set)
        if "networks" in props:
                nets=props["networks"].set
                if not nets <= allowed:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if "peer" in props and not nets.contains(props["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,inherited,allowed)
742
# Read the input.  In userv service mode the header file (which may use
# "include") is processed first and its output kept for rebuilding the
# sites file; then the caller's input is read from stdin, where include is
# not permitted.  Otherwise process the named input file, or stdin.
if service:
        headerinput=pfilepath(header,allow_include=True)
        userinput=sys.stdin.readlines()
        pfile("user input",userinput)
else:
        if inputfile is None:
                pfile("stdin",sys.stdin.readlines())
        else:
                pfilepath(inputfile)
752
# Prune branches with no sites, then check constraints over the whole
# tree, starting with no inherited properties and
# ipaddrset.complete_set() (presumably every address) as the permitted
# range.
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

# Stop if anything was wrong with the input.
if complaints>0:
        if complaints==1: print("There was 1 problem.")
        else: print("There were %d problems."%(complaints))
        sys.exit(1)
complaints=None # arranges to crash if we complain later
761
if service:
        # Put the user's input into their group file, and rebuild the main
        # sites file
        gname=group.groupname()
        tmpgroupfile=os.path.join(groupfiledir,"T"+gname)
        with open(tmpgroupfile,'w') as f:
                f.write("# Section submitted by user %s, %s\n"%
                        (user,time.asctime(time.localtime(time.time()))))
                f.write("# Checked by make-secnet-sites version %s\n\n"%VERSION)
                for i in userinput: f.write(i)
                f.write("\n")
        # Rename into place only once the file is complete and closed.
        os.rename(tmpgroupfile,os.path.join(groupfiledir,"R"+gname))
        with open(sitesfile+"-tmp",'w') as f:
                f.write("# sites file autogenerated by make-secnet-sites\n")
                f.write("# generated %s, invoked by %s\n"%
                        (time.asctime(time.localtime(time.time())),user))
                f.write("# use make-secnet-sites to turn this file into a\n")
                f.write("# valid /etc/secnet/sites.conf file\n\n")
                for i in headerinput: f.write(i)
                # os.listdir order is arbitrary; sort so the generated
                # sites file is deterministic.
                for i in sorted(os.listdir(groupfiledir)):
                        if i.startswith('R'):
                                with open(os.path.join(groupfiledir,i)) as j:
                                        f.write(j.read())
                f.write("# end of sites file\n")
        os.rename(sitesfile+"-tmp",sitesfile)
else:
        if outputfile is None:
                of=sys.stdout
                outputsites(of)
        else:
                # Write a temporary file and close it (flushing it) before
                # renaming it over the real one, so the output file is
                # replaced atomically with complete contents.
                tmp_outputfile=outputfile+'~tmp~'
                with open(tmp_outputfile,'w') as of:
                        outputsites(of)
                os.rename(tmp_outputfile,outputfile)