chiark / gitweb /
make-secnet-sites: Introduce new OpMod classes
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73 import base91
74
75 from argparseactionnoyes import ActionNoYes
76
VERSION="0.1.18"

# Default value for --output-version (sites file output format version).
max_version = 2

from sys import version_info
if version_info.major == 2:
        # Python 2 only: wrap stdin/stdout and open() so that they read and
        # write UTF-8 text.
        import codecs
        import io
        sys.stdin = codecs.getreader('utf-8')(sys.stdin)
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
        def open(f, m='r'):
                return io.open(f, m, encoding='utf-8')

# Limits used by the Tainted validators: the *_bits entries are key sizes
# in bits; the others are maximum string lengths for that kind of token.
max={
        'rsa_bits': 8200,
        'name':     33,
        'dh_bits':  8200,
        'algname':  127,
}
90
def debugrepr(*args):
        "Write repr() of the argument tuple to stderr when debug_level > 0."
        if debug_level <= 0:
                return
        print(repr(args), file=sys.stderr)
94
def base91s_encode(bindata):
        """Return bindata base91-encoded, with every double quote replaced by '-'."""
        encoded = base91.encode(bindata)
        return encoded.replace('"', "-")
97
def base91s_decode(string):
        """Inverse of base91s_encode: turn each '-' back into a double quote, then base91-decode."""
        restored = string.replace("-", '"')
        return base91.decode(restored)
100
class Tainted:
        """A string from untrusted input that must be validated before use.

        Each validator method (name, keyword, number, host, ...) checks the
        string, records the verdict in _ok and returns either the value or a
        harmless substitute; failures are reported through complain().
        str() on a Tainted deliberately raises, so an unchecked value cannot
        be interpolated by accident.  output() returns '' for rejected
        values, the string for accepted ones, and exits for unchecked ones.
        """
        def __init__(self,s,tline=None,tfile=None):
                self._s=s
                self._ok=None   # None: unchecked, True: accepted, False: rejected
                # Default to the current global input position.
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                "Mark this value rejected and complain about it; returns False."
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                """Check the length limit (default max[what]) and that the
                regex bad does not match anywhere in the string."""
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s):
                        return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                "Record the verdict and return ifgood or ifbad accordingly."
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                ok=self._re_ok(Tainted.bad_name,what)
                # On failure, substitute a unique placeholder name.
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                # One hex digit per 4 bits, rounded up (integer division).
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                # An n-bit number has at most ceil(n*log10(2)) decimal digits.
                maxlen=int(math.ceil(max[kind+'_bits'] * math.log10(2)))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Return the value as an int in minn..maxx; minn if invalid."""
                # not for bignums
                v=minn  # fallback so the failure path never sees an unbound v
                ok=self._re_ok(Tainted.bad_num,what,10)
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
234
class ArgActionLambda(argparse.Action):
        """argparse action that hands control to an arbitrary callable.

        The callable is supplied as the extra ``fn`` keyword to
        add_argument() and is invoked as
        fn(values, namespace, parser, option_string).
        """
        def __init__(self, fn, **kwargs):
                self.fn = fn
                super(ArgActionLambda, self).__init__(**kwargs)
        def __call__(self,ap,ns,values,option_string):
                callback = self.fn
                callback(values, ns, ap, option_string)
241
class PkmBase():
        """Base class for public-key output modes.

        site_start() remembers the site's public key path and creates a fresh
        FilterState; the remaining hooks do nothing here.
        """
        def site_start(self,pubkeys_path):
                self._pa = pubkeys_path
                self._fs = FilterState()
        def site_serial(self,serial):
                return None
        def write_key(self,k):
                return None
        def site_finish(self,confw):
                return None
249
class PkmSingle(PkmBase):
        """Public-key mode 'single': write exactly one "key" line per site.

        Entries for which okforonlykey() is false (for the selected
        output_version) are skipped.  Zero or several remaining keys is
        reported via complain().
        """
        opt = 'single'
        help = 'write one public key per site to sites.conf'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._outk = []
        def write_key(self,k):
                if k.okforonlykey(output_version,self._fs):
                        self._outk.append(k)
        def site_finish(self,confw):
                if not self._outk:
                        complain("site with no public key")
                elif len(self._outk) > 1:
                        debugrepr('outk ', self._outk)
                        complain(
 "site with multiple public keys, without --pubkeys-install (maybe --output-version=1 would help)"
                        )
                else:
                        confw.write("key %s;\n"%str(self._outk[0]))
269
class PkmInstall(PkmBase):
        """Public-key mode 'install': write the site's keys to a file beside
        its public key path, and point sites.conf at it with "peer-keys".
        """
        opt = 'install'
        help = 'install public keys in public key directory'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                # Written to <path>~tmp, renamed to <path>~update when complete.
                self._pw=open(self._pa+'~tmp','w')
        def site_serial(self,serial):
                self._pw.write('serial %s\n' % serial)
        def write_key(self,k):
                wout=k.forpub(output_version,self._fs)
                self._pw.write(' '.join(wout))
                self._pw.write('\n')
        def site_finish(self,confw):
                self._pw.close()
                os.rename(self._pa+'~tmp',self._pa+'~update')
                # Same line that PkmElide.site_finish writes.  It is written
                # directly because PkmElide.site_finish(self,...) is an unbound
                # method under Python 2 and rejects a PkmInstall instance.
                confw.write("peer-keys \"%s\";\n"%self._pa)
286
class PkmElide(PkmBase):
        """Public-key mode 'elide': emit only a peer-keys reference to the
        site's public key path; no keys are written out."""
        opt = 'elide'
        help = 'no public keys in sites.conf output nor in directory'
        def site_finish(self,confw):
                text = 'peer-keys "%s";\n' % (self._pa,)
                confw.write(text)
292
class OpBase():
        "Base class for the program's operating modes."

class OpConf(OpBase):
        "Default mode: not a userv service invocation."
        def is_service(self):
                return 0

class OpUserv(OpBase):
        "userv service fragment update mode (selected with --userv / -u)."
        opts = ['--userv','-u']
        help = 'userv service fragment update mode'
        def is_service(self):
                return 1
303
def parse_args():
        """Parse the command line and set this module's configuration globals.

        Always sets debug_level, opmode (an OpBase instance), service,
        prefix, key_prefix, output_version, pubkeys_dir and pubkeys_mode
        (a Pkm* class).  In userv service mode it also sets header,
        groupfiledir, sitesfile, group (a Tainted value supplied by the
        caller) and user, after checking that group is one of the caller's
        USERV_GROUP groups.  Otherwise it sets inputfile and outputfile
        (None when not given).  Exits with status 1 on bad usage.
        """
        global opmode
        global service
        global inputfile
        global header
        global groupfiledir
        global sitesfile
        global outputfile
        global group
        global user
        global of
        global prefix
        global key_prefix
        global debug_level
        global output_version
        global pubkeys_dir
        global pubkeys_mode

        ap = argparse.ArgumentParser(description='process secnet sites files')
        def add_opmode(how):
                ap.add_argument(*how().opts, action=ArgActionLambda,
                        nargs=0,
                        fn=(lambda v,ns,*x: setattr(ns,'opmode',how)),
                        help=how().help)
        add_opmode(OpUserv)
        ap.add_argument('--conf-key-prefix', action=ActionNoYes,
                        default=True,
                 help='prefix conf file key names derived from sites data')
        def add_pkm(how):
                ap.add_argument('--pubkeys-'+how().opt, action=ArgActionLambda,
                        nargs=0,
                        fn=(lambda v,ns,*x: setattr(ns,'pkm',how)),
                        help=how().help)
        add_pkm(PkmInstall)
        add_pkm(PkmSingle)
        add_pkm(PkmElide)
        ap.add_argument('--pubkeys-dir',  nargs=1,
                        help='public key directory',
                        default=['/var/lib/secnet/pubkeys'])
        ap.add_argument('--output-version', nargs=1, type=int,
                        help='sites file output version',
                        default=[max_version])
        ap.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        ap.add_argument('--debug', '-D', action='count', default=0)
        ap.add_argument('arg',nargs=argparse.REMAINDER)
        av = ap.parse_args()
        debug_level = av.debug
        debugrepr('av',av)
        opmode = getattr(av,'opmode',OpConf)()
        service = opmode.is_service()
        prefix = '' if av.prefix is None else av.prefix[0]
        key_prefix = av.conf_key_prefix
        output_version = av.output_version[0]
        pubkeys_dir = av.pubkeys_dir[0]
        pubkeys_mode = getattr(av,'pkm',PkmSingle)
        if service:
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (header, groupfiledir, sitesfile, group) = av.arg
                # untrusted argument from caller
                group = Tainted(group,0,'command line')
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                ugs=os.environ["USERV_GROUP"]
                if not any(group==g for g in ugs.split()):
                        # Tainted.__str__ deliberately raises, so format the
                        # raw string rather than the Tainted object.
                        print("caller not in group %s"%group.raw())
                        sys.exit(1)
        else:
                # Usage is [infile [outfile]].  av.arg excludes the program
                # name (unlike sys.argv), so more than 2 is too many.
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (inputfile, outputfile) = (av.arg + [None]*2)[0:2]

parse_args()
389
390 # Classes describing possible datatypes in the configuration file
391
class basetype:
        """Common protocol for configuration types.

        add() is called when a node already has this property; for these
        single-valued types that is reported as an error.  forsites() returns
        its copy argument unchanged by default; subclasses may substitute
        other output.
        """
        def add(self,obj,w):
                msg = "%s %s already has property %s defined" % (
                        obj.type, obj.name, w[0].raw())
                complain(msg)
        def forsites(self,version,copy,fs):
                return copy
399
class conflist:
        "A list of some kind of configuration type."
        def __init__(self,subtype,w):
                self.subtype=subtype
                self.list=[subtype(w)]
        def add(self,obj,w):
                # obj is unused; the parameter exists for the basetype protocol.
                self.list.append(self.subtype(w))
        def __str__(self):
                return ', '.join(str(item) for item in self.list)
        def forsites(self,version,copy,fs):
                # Delegate to the most recently added entry.
                return self.list[-1].forsites(version,copy,fs)

def listof(subtype):
        "Return a constructor that builds a conflist of subtype from w."
        def make(w):
                return conflist(subtype, w)
        return make
414
class single_ipaddr (basetype):
        "A single IP address, parsed with ipaddress.ip_address (v4 or v6)."
        def __init__(self,w):
                text=w[1].raw_mark_ok()
                self.addr=ipaddress.ip_address(text)
        def __str__(self):
                return '"{}"'.format(self.addr)
421
class networks (basetype):
        "A set of IP addresses specified as a list of networks"
        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for tok in w[1:]:
                        # strict=True: host bits set in a network are an error
                        net=ipaddress.ip_network(tok.raw_mark_ok(),strict=True)
                        self.set.append([net])
        def __str__(self):
                quoted=['"%s"' % n for n in self.set.networks()]
                return ",".join(quoted)
431
class dhgroup (basetype):
        "A Diffie-Hellman group: modulus and generator, each given in hex."
        def __init__(self,w):
                self.mod=w[1].bignum_16('dh','dh mod')
                self.gen=w[2].bignum_16('dh','dh gen')
        def __str__(self):
                return 'diffie-hellman("{0}","{1}")'.format(self.mod,self.gen)
439
class hash (basetype):
        "A choice of hash function: 'md5' or 'sha1' (None if unrecognised)."
        def __init__(self,w):
                token=w[1]
                self.ht=token.raw()
                if self.ht in ('md5','sha1'):
                        token.raw_mark_ok()
                else:
                        complain("unknown hash type %s"%(self.ht))
                        self.ht=None
        def __str__(self):
                return '%s' % (self.ht,)
452
class email (basetype):
        "An email address, rendered as <address>."
        def __init__(self,w):
                self.addr=w[1].email()
        def __str__(self):
                return '<' + self.addr + '>'
459
class boolean (basetype):
        """A boolean, judged by its first character: T/t/Y/y/1 is true and
        F/f/N/n/0 is false; anything else is complained about."""
        def __init__(self,w):
                tok=w[1]
                text=tok.raw()
                for pattern,value in (('[TtYy1]',True),('[FfNn0]',False)):
                        if re.match(pattern,text):
                                self.b=value
                                tok.raw_mark_ok()
                                break
                else:
                        complain("invalid boolean value")
        def __str__(self):
                return 'True' if self.b else 'False'
474
class num (basetype):
        "A decimal number in the range 0..0x7fffffff."
        def __init__(self,w):
                self.n=w[1].number(0,0x7fffffff)
        def __str__(self):
                return '%d' % (self.n,)
481
class serial (basetype):
        "A public key set serial: 4 bytes written as 8 hex digits."
        def __init__(self,w):
                self.i=w[1].hexid(4,'serial')
        def __str__(self):
                return self.i
        def forsites(self,version,copy,fs):
                # Serial lines are emitted only for sites format version >= 2.
                return copy if version >= 2 else []
490
class address (basetype):
        "A DNS name (or address literal) and UDP port number"
        def __init__(self,w):
                self.adr=w[1].host()
                # UDP ports are 16-bit; 65535 is the largest valid port and
                # port 0 is reserved.
                self.port=w[2].number(1,65535,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
498
class inpub (basetype):
        "Base for public-key-list entries: sites output comes from forpub()."
        def forsites(self,version,xcopy,fs):
                # The copied input line is ignored; output is regenerated.
                out = self.forpub(version,fs)
                return out
502
class pubkey (inpub):
        "Some kind of public key: an algorithm name plus base91 key data."
        def __init__(self,w):
                self.a=w[1].name('algname')
                self.d=w[2].base91()
        def __str__(self):
                return 'make-public("{}","{}")'.format(self.a,self.d)
        def forpub(self,version,fs):
                # Only sites format version 2 onwards has "pub" lines.
                if version >= 2:
                        return ['pub', self.a, self.d]
                return []
        def okforonlykey(self,version,fs):
                return bool(self.forpub(version,fs))
515
class rsakey (pubkey):
        "An old-style RSA public key"
        def __init__(self,w):
                self.l=w[1].number(0,max['rsa_bits'],'rsa len')
                self.e=w[2].bignum_10('rsa','rsa e')
                self.n=w[3].bignum_10('rsa','rsa n')
                if len(w) >= 5:
                        # optional trailing email address: validated, not kept
                        w[4].email()
                self.a='rsa1'
                # Also hold the key as new-style "rsa1" data so that
                # pubkey.forpub() can produce output for versions >= 2.
                payload=b'%d %s %s' % (self.l,
                                       self.e.encode('ascii'),
                                       self.n.encode('ascii'))
                self.d=base91s_encode(payload)
        def __str__(self):
                # rsa-public(...) keeps generated files usable by old
                # secnet executables.
                return 'rsa-public("%s","%s")'%(self.e,self.n)
        def forpub(self,version,fs):
                if version >= 2:
                        return pubkey.forpub(self,version,fs)
                # For version < 2, only keys outside any explicit key group
                # (fs.pkg still '00000000') are written.
                if fs.pkg != '00000000':
                        return []
                return ['pubkey', str(self.l), self.e, self.n]
539
class rsakey_newfmt(rsakey):
        "An old-style RSA public key in new-style sites format"
        # This is its own class simply to have its own constructor.
        def __init__(self,w):
                self.a=w[1].name()
                assert(self.a == 'rsa1')
                self.d=w[2].base91()
                try:
                        fields=(base91s_decode(self.d)
                                .decode('ascii')
                                .split(' '))
                except UnicodeDecodeError:
                        complain('rsa1 key in new format has bad base91')
                        fields=None
                if fields is not None and len(fields) < 3:
                        complain('rsa1 key in new format has too few fields')
                        fields=None
                if fields is None:
                        # The error has been reported; use placeholder fields
                        # so the object is still fully formed rather than
                        # failing with an unbound variable or IndexError.
                        fields=['0','0','0']
                w_inner=list(map(Tainted, ['X-PUB-RSA1'] + fields))
                rsakey.__init__(self,w_inner)
557
class pubkey_group(inpub):
        """Public key group introducer ("pkg", or "pkgf" for a fallback group).

        It appears in the site's list of keys mixed in with the keys.
        forpub() records its group id in the FilterState, where later key
        entries (e.g. rsakey.forpub) consult it.
        """
        def __init__(self,w,fallback):
                self.i=w[1].hexid(4,'pkg-id')
                self.fallback=fallback
        def forpub(self,version,fs):
                fs.pkg=self.i
                if version >= 2:
                        kw = 'pkgf' if self.fallback else 'pkg'
                        return [kw, self.i]
                return []
        def okforonlykey(self,version,fs):
                # Run forpub for its effect on fs; a group marker is never a key.
                self.forpub(version,fs)
                return False
571         
def somepubkey(w):
        "Build the key-list entry for a tokenised line, dispatching on w[0]."
        kw=w[0]
        if kw=='pubkey':
                return rsakey(w)
        if kw=='pub':
                if w[1]=='rsa1':
                        return rsakey_newfmt(w)
                return pubkey(w)
        if kw=='pkg':
                return pubkey_group(w,False)
        if kw=='pkgf':
                return pubkey_group(w,True)
        assert(False)
586
# Possible properties of configuration nodes.
# Each entry is (constructor, description), or for an alias
# (constructor, description, name the value is stored under).
keywords={
        'contact':          (email, "Contact address"),
        'dh':               (dhgroup, "Diffie-Hellman group"),
        'hash':             (hash, "Hash function"),
        'key-lifetime':     (num, "Maximum key lifetime (ms)"),
        'setup-timeout':    (num, "Key setup timeout (ms)"),
        'setup-retries':    (num, "Maximum key setup packet retries"),
        'wait-time':        (num, "Time to wait after unsuccessful key setup (ms)"),
        'renegotiate-time': (num, "Time after key setup to begin renegotiation (ms)"),
        'restrict-nets':    (networks, "Allowable networks"),
        'networks':         (networks, "Claimed networks"),
        'serial':           (serial, "public key set serial"),
        'pkg':              (listof(somepubkey), "start of public key group", 'pub'),
        'pkgf':             (listof(somepubkey), "start of fallback public key group", 'pub'),
        'pub':              (listof(somepubkey), "new style public site key"),
        'pubkey':           (listof(somepubkey), "Old-style RSA public site key", 'pub'),
        'peer':             (single_ipaddr, "Tunnel peer IP address"),
        'address':          (address, "External contact address and port"),
        'mobile':           (boolean, "Site is mobile"),
}
608
def sp(name,value):
        "Default property output: a plain 'name value;' config line."
        return "{0} {1};\n".format(name,value)

# All levels support these properties.  Each maps to a function taking
# (name, value) and returning the text to write; contact and
# restrict-nets are written only as comments.
global_properties={
        'contact':
                (lambda name,value: "# Contact email address: %s\n"%(value)),
        'dh':               sp,
        'hash':             sp,
        'key-lifetime':     sp,
        'setup-timeout':    sp,
        'setup-retries':    sp,
        'wait-time':        sp,
        'renegotiate-time': sp,
        'restrict-nets':
                (lambda name,value: "# restrict-nets %s\n"%value),
}
625
class level:
        """A level in the configuration hierarchy.

        Built from a tokenised line w: w[0] is the level keyword and w[1]
        the node's name.  Subclasses set depth, leaf, type and the
        allow_properties / require_properties tables.
        """
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}
        def __init__(self,w):
                self.type=w[0].keyword()
                self.name=w[1].name()
                self.properties={}   # property name -> value object
                self.children={}
        def indent(self,w,t):
                "Write t spaces of indentation to w."
                # " "*t rather than slicing a fixed-width string of spaces,
                # so deep indentation is not silently capped.
                w.write(" "*t)
        def prop_out(self,n):
                "Format property n with this level's output function for it."
                return self.allow_properties[n](n,str(self.properties[n]))
        def output_props(self,w,ind):
                "Write, in name order, each property that has an output function."
                for i in sorted(self.properties.keys()):
                        if self.allow_properties[i]:
                                self.indent(w,ind)
                                w.write("%s"%self.prop_out(i))
        def kname(self):
                """Config key name: the node name, prefixed by the upper-cased
                first letter of its type when key_prefix is set."""
                return ((self.type[0].upper() if key_prefix else '')
                        + self.name)
        def output_data(self,w,path):
                "Write this node, then its children in key order, as a config block."
                ind = 2*len(path)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.output_props(w,ind+2)
                if self.depth==1: w.write("\n")
                for k in sorted(self.children.keys()):
                        c=self.children[k]
                        c.output_data(w,path+(c,))
                self.indent(w,ind)
                w.write("};\n")
660
class vpnlevel(level):
        "VPN level in the configuration hierarchy"
        depth=1
        leaf=0
        type="vpn"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"VPN admin contact address"
        }
        def __init__(self,w):
                level.__init__(self,w)
        def output_vpnflat(self,w,path):
                "Output flattened list of site names for this VPN"
                ind=2*(len(path)+1)
                inner=path+(self,)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                for child in self.children.values():
                        child.output_vpnflat(w,inner)
                w.write("\n")
                self.indent(w,ind+2)
                names=[child.kname() for child in self.children.values()]
                w.write("all-sites %s;\n"%','.join(names))
                self.indent(w,ind)
                w.write("};\n")
686
class locationlevel(level):
        "Location level in the configuration hierarchy"
        depth=2
        leaf=0
        type="location"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"Location admin contact address",
        }
        def __init__(self,w):
                level.__init__(self,w)
                self.group=w[2].groupname()
        def output_vpnflat(self,w,path):
                """Write one line: this location's key name, then a comma-separated
                list of each child's <prefix>vpn-data/... key path."""
                ind=2*(len(path)+1)
                self.indent(w,ind)
                def child_path(x):
                        keys=[i.kname() for i in path+(self,x)]
                        return '/'.join([prefix+"vpn-data"] + keys)
                paths=','.join(child_path(x) for x in self.children.values())
                w.write("%s %s;\n"%(self.kname(),paths))
712
class sitelevel(level):
        "Site level (i.e. a leafnode) in the configuration hierarchy"
        depth=3
        leaf=1
        type="site"
        allow_properties=global_properties.copy()
        # None: allowed, but not written by output_props; these are handled
        # explicitly in output_data (directly or via the key output mode).
        allow_properties.update({
         'address':sp,
         'networks':None,
         'peer':None,
         'serial':None,
         'pkg':None,
         'pkgf':None,
         'pub':None,
         'pubkey':None,
         'mobile':sp,
        })
        require_properties={
         'dh':"Diffie-Hellman group",
         'contact':"Site admin contact address",
         'networks':"Networks claimed by the site",
         'hash':"hash function",
         'peer':"Gateway address of the site",
        }
        def mangle_name(self):
                return self.name.replace('/',',')
        def pubkeys_path(self):
                return pubkeys_dir + '/peer.' + self.mangle_name()
        def __init__(self,w):
                level.__init__(self,w)
        def output_data(self,w,path):
                "Write this site's config block, including keys and netlink."
                ind=2*len(path)
                np='/'.join(map(lambda i: i.name, path))
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.indent(w,ind+2)
                w.write("name \"%s\";\n"%(np,))
                self.indent(w,ind+2)

                pkm = pubkeys_mode()
                debugrepr('pkm ',pkm)
                pkm.site_start(self.pubkeys_path())
                if 'serial' in self.properties:
                        pkm.site_serial(self.properties['serial'])

                # 'pub' is not a required property.  With no keys at all, go
                # on to pkm.site_finish (PkmSingle complains "site with no
                # public key") instead of raising KeyError here.
                pubs=self.properties.get("pub")
                for k in (pubs.list if pubs is not None else []):
                        debugrepr('pubkeys ', k)
                        pkm.write_key(k)

                pkm.site_finish(w)

                self.output_props(w,ind+2)
                self.indent(w,ind+2)
                w.write("link netlink {\n")
                self.indent(w,ind+4)
                w.write("routes %s;\n"%str(self.properties["networks"]))
                self.indent(w,ind+4)
                w.write("ptp-address %s;\n"%str(self.properties["peer"]))
                self.indent(w,ind+2)
                w.write("};\n")
                self.indent(w,ind)
                w.write("};\n")
775
# Levels in the configuration file: maps each level keyword to the
# class that implements that level.
levels={
        'vpn':      vpnlevel,
        'location': locationlevel,
        'site':     sitelevel,
}
779
780 def complain(msg):
781         "Complain about a particular input line"
782         moan(("%s line %d: "%(file,line))+msg)
def moan(msg):
        """Report a problem that is not tied to a particular input line.

        Prints msg and bumps the global complaint count. Once checking has
        finished, complaints is set to None, and any further call exits
        with status 1.
        """
        global complaints
        print(msg)
        if complaints is None:
                sys.exit(1)
        complaints+=1
789
class UntaintedRoot():
        """Minimal stand-in for an input word, used to build the root node.

        Both name() and keyword() return the string given to the constructor.
        """
        def __init__(self,s):
                self._s=s

        def name(self):
                return self._s

        keyword=name
794
# The root node: all vpns are children of it. level() is given two
# untainted words standing in for the words of an input line.
root=level([UntaintedRoot('root'), UntaintedRoot('root')])
# Stack of currently open nodes, from root down to the innermost one
obstack=[root]
# pline() refuses new nodes whose depth is below this value;
# end-definitions raises it to sitelevel.depth
allow_defs=0
799
def set_property(obj,w):
        """Set (or extend) a property on configuration node obj.

        w is the list of words of the property line; w[0] names the
        property. keywords[name][0] constructs a new property from w, and
        an optional keywords[name][2] gives the canonical name for an alias.
        If obj already has the property, its add() method is called instead.
        Returns the property object stored on obj.
        """
        propname=w[0].raw_mark_ok()
        kw=keywords[propname]
        canonical=kw[2] if len(kw) >= 3 else propname
        props=obj.properties
        if canonical not in props:
                props[canonical]=kw[0](w)
        else:
                props[canonical].add(obj,w)
        return props[canonical]
811
class FilterState:
        """State carried from line to line while one input file is processed.

        pline() passes it to each property's forsites() method. The pkg
        attribute presumably tracks the 'pkg' value in force for the current
        node (confirm against the property classes).
        """
        def __init__(self):
                self.reset()

        def reset(self):
                """Return to the initial state.

                pline() calls this on entry to every new node, in particular
                at the start of each site.
                """
                self.pkg = '00000000'
819
def pline(il,filterstate,allow_include=False):
        """Process one line of a configuration file.

        il is the raw input line (newline optional). filterstate is the
        FilterState of the file being read. allow_include says whether the
        'include' directive is permitted.

        Updates the global parse state (obstack, allow_defs and the node
        tree under root). Returns a list of strings which, concatenated,
        form the text to copy to the output for this line. A rejected line
        yields an empty list after complain() has been called.
        """
        global allow_defs, obstack, root
        w=il.rstrip('\n').split()
        if len(w)==0: return ['']
        w=[Tainted(x) for x in w]
        keyword=w[0]
        current=obstack[-1]
        copyout_core=lambda: ' '.join([ww.output() for ww in w])
        indent='    '*len(obstack)
        copyout=lambda: [indent + copyout_core() + '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                # From now on only site-level (and deeper) nodes may be new
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                return pfilepath(newfile,allow_include=allow_include)
        if keyword.raw() in levels:
                # A level line must name the node it opens; without this
                # check w[1] below would raise IndexError on bad input.
                if len(w) < 2:
                        complain("%s requires a name"%(keyword.raw()))
                        return []
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if service and group and current.depth==2:
                                if group!=current.group:
                                        complain("Incorrect group!")
                                w[2].groupname()
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                filterstate.reset()
                obstack.append(current)
                return copyout()
        # Anything else must be a property. Report keywords that are not
        # properties at all separately from ones merely not allowed here;
        # previously this message was unreachable dead code.
        if keyword.raw() not in keywords:
                complain("unknown keyword '%s'"%(keyword.raw()))
                return []
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        if current.depth == vpnlevel.depth < allow_defs:
                # At vpn level once allow_defs has been raised above it
                # (as end-definitions does)
                complain("Not allowed to set VPN properties here")
                return []
        prop=set_property(current,w)
        out=[copyout_core()]
        out=prop.forsites(output_version,out,filterstate)
        if len(out)==0:
                # The property was filtered out; keep it as a comment
                return [indent + '#' + copyout_core() + '\n']
        return [indent + ' '.join(out) + '\n']
895
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process it with pfile().

        Returns pfile()'s list of output strings. The file is read
        completely and closed before processing. This way it is not held
        open across nested includes, and it does not leak if processing
        raises.
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
901
def pfile(name,lines,allow_include=False):
        """Process a file given as a list of lines.

        name is the file name reported by complain() and used to resolve
        relative 'include' paths. Lines starting with '#' are skipped.
        Returns the concatenated output of pline() for the remaining lines.

        The global file/line position is saved on entry and restored on
        exit. A nested 'include' (which re-enters pfile) therefore does
        not leave later complaints and includes in the including file
        using the included file's name and line numbers.
        """
        global file,line
        saved=(file,line)
        file=name
        line=0
        outlines=[]
        filterstate = FilterState()
        try:
                for i in lines:
                        line=line+1
                        if i.startswith('#'): continue
                        outlines += pline(i,filterstate,allow_include=allow_include)
        finally:
                file,line=saved
        return outlines
914
def outputsites(w):
        """Write the secnet configuration include file to w.

        Writes a comment header, then the raw vpn-data section, the
        per-vpn flattened lists, and finally the all-sites list. Each
        section name is preceded by the global prefix.
        """
        vpns=root.children.values()
        stamp=time.asctime(time.localtime(time.time()))

        w.write("# secnet sites file autogenerated by make-secnet-sites "
                "version %s\n"%VERSION)
        w.write("# %s\n"%stamp)
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in vpns:
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in vpns:
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        refs=["%svpn/%s/all-sites"%(prefix,vpn.kname()) for vpn in vpns]
        w.write(prefix+"all-sites %s;\n"%",".join(refs))
938
# Name of the input file being processed, and the current line number
# in it; complain() includes both in its messages
file=None
line=0
# Problems reported so far; set to None once checking has finished, so
# that any later moan() exits immediately
complaints=0
942
943 # Sanity check section
# Delete non-leaf nodes that have no leaf nodes anywhere below them
945
def live(n):
        """Return 1 if n is a leaf node or has a leaf somewhere below it, else 0."""
        if n.leaf:
                return 1
        return 1 if any(live(child) for child in n.children.values()) else 0
def delempty(n):
        "Recursively remove the children of n that have no leaf node below them"
        for key in list(n.children):
                child=n.children[key]
                delempty(child)
                if not live(child):
                        del n.children[key]
958
# Check that all constraints are met (as far as I can tell,
# restrict-nets/networks/peer are the only special cases)
961
def checkconstraints(n,p,ra):
        """Recursively check node n and its descendants, reporting via moan().

        p holds the properties inherited from the ancestors (it is not
        modified). ra is the address set allowed by the ancestors'
        restrict-nets. The function checks for required and forbidden
        properties, networks falling outside the allowed range, and a peer
        address that is not inside the node's networks.
        """
        inherited=dict(p)
        inherited.update(n.properties)
        for prop in n.require_properties.keys():
                if prop not in inherited:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,prop))
        for prop in inherited.keys():
                if prop not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,prop))
        # Check address range restrictions
        props=n.properties
        if "restrict-nets" not in props:
                allowed=ra
        else:
                allowed=ra.intersection(props["restrict-nets"].set)
        if "networks" in props:
                if not props["networks"].set <= allowed:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if "peer" in props and not props["networks"].set.contains(
                                props["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,inherited,allowed)
987
# Read and parse the input. In service mode the trusted header (which
# may use include) comes first, then the user's submission on stdin.
# Otherwise a single input file, or stdin, is parsed.
if not service:
        if inputfile is not None:
                pfilepath(inputfile)
        else:
                pfile("stdin",sys.stdin.readlines())
else:
        headerinput=pfilepath(header,allow_include=True)
        userinput=sys.stdin.readlines()
        pfile("user input",userinput)

# Prune nodes without leaves, then validate the whole tree
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

if complaints>0:
        if complaints==1:
                print("There was 1 problem.")
        else:
                print("There were %d problems."%(complaints))
        sys.exit(1)
complaints=None # arranges to crash if we complain later
1006
if service:
        # Put the user's input into their group file, and rebuild the main
        # sites file. Each file is written under a temporary name, closed,
        # and only then renamed into place.
        gname=group.groupname()
        with open(groupfiledir+"/T"+gname,'w') as f:
                f.write("# Section submitted by user %s, %s\n"%
                        (user,time.asctime(time.localtime(time.time()))))
                f.write("# Checked by make-secnet-sites version %s\n\n"%VERSION)
                for i in userinput: f.write(i)
                f.write("\n")
        os.rename(groupfiledir+"/T"+gname,
                  groupfiledir+"/R"+gname)
        with open(sitesfile+"-tmp",'w') as f:
                f.write("# sites file autogenerated by make-secnet-sites\n")
                f.write("# generated %s, invoked by %s\n"%
                        (time.asctime(time.localtime(time.time())),user))
                f.write("# use make-secnet-sites to turn this file into a\n")
                f.write("# valid /etc/secnet/sites.conf file\n\n")
                for i in headerinput: f.write(i)
                # os.listdir order is arbitrary; sort it so the generated
                # sites file is reproducible
                for i in sorted(os.listdir(groupfiledir)):
                        if i.startswith('R'):
                                with open(groupfiledir+"/"+i) as j:
                                        f.write(j.read())
                f.write("# end of sites file\n")
        os.rename(sitesfile+"-tmp",sitesfile)
else:
        if outputfile is None:
                of=sys.stdout
                outputsites(of)
        else:
                tmp_outputfile=outputfile+'~tmp~'
                # Close (and so flush) the temporary file before renaming it
                # over outputfile, so outputfile never appears partially
                # written.
                with open(tmp_outputfile,'w') as of:
                        outputsites(of)
                os.rename(tmp_outputfile,outputfile)