chiark / gitweb /
make-secnet-sites: Tolerate missing group in userv sites file
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73 import base91
74
75 from argparseactionnoyes import ActionNoYes
76
VERSION = "0.1.18"

# Default for --output-version (see parse_args).
max_version = 2

from sys import version_info
if version_info[0] == 2:
        # Python 2 compatibility: make stdio and open() speak UTF-8 text.
        import codecs
        import io
        sys.stdin = codecs.getreader('utf-8')(sys.stdin)
        sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
        def open(f, m='r'):
                return io.open(f, m, encoding='utf-8')

# Length limits for untrusted input: '<kind>_bits' entries bound the
# bignum validators, the others are character counts for names.
# NOTE: this name shadows the builtin max() throughout the module.
max = dict(rsa_bits=8200, name=33, dh_bits=8200, algname=127)
90
def debugrepr(*args):
        """Dump repr() of the argument tuple to stderr when debugging.

        Nothing is printed unless the global debug_level (set from the
        -D/--debug count by parse_args) is positive.
        """
        if not debug_level > 0:
                return
        sys.stderr.write(repr(args) + "\n")
94
def base91s_encode(bindata):
        """Encode bytes as "base91s": base91 with every '"' turned into '-'.

        The substitution keeps double quotes out of the result, which is
        later embedded inside "..." strings in the generated output.
        """
        encoded = base91.encode(bindata)
        return encoded.replace('"', "-")
97
def base91s_decode(string):
        """Inverse of base91s_encode: map '-' back to '"', then base91-decode."""
        restored = string.replace("-", '"')
        return base91.decode(restored)
100
class Tainted:
        """A string from untrusted input, tagged with where it came from.

        The raw text is only released through validation methods (name,
        keyword, number, host, email, ...).  Each records the outcome of
        its check in self._ok (None: unchecked, True: passed, False:
        failed) and returns either the value or a harmless placeholder.
        str() on a Tainted deliberately raises, so unchecked data cannot
        reach the output by accident.
        """
        def __init__(self,s,tline=None,tfile=None):
                self._s=s
                self._ok=None
                # Without explicit provenance, fall back to the globals
                # `line` and `file` (defined elsewhere; presumably the
                # current input position -- TODO confirm).
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                # Record a failed check and report it via complain().
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                # `bad` matches any forbidden character or shape; maxlen
                # defaults to the per-kind limit in the global max table.
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s):
                        return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                # Commit the outcome of a check and pick the value to
                # hand back: ifgood on success, placeholder ifbad otherwise.
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                # Peek at the text without marking it checked.
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                # Release leftover data; unchecked data is a fatal error.
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                ok=self._re_ok(Tainted.bad_name,what)
                # A bad name is replaced by a unique reserved _-name.
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                # Hex digits for a max[kind_bits]-bit number, rounded up.
                # Integer division matters: a float limit would let one
                # extra digit through the length check.
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                # Decimal digits for a max[kind_bits]-bit number:
                # ceil(bits * log10(2)).  (Dividing by log10(2) instead
                # would allow about eleven times too many digits.)
                maxlen=int(math.ceil(max[kind+'_bits'] * math.log10(2)))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Return the value as an int in [minn, maxx] (not for bignums).

                On any failure this complains and returns minn.
                """
                # Bind v before validating: if the syntax check fails, v
                # must still exist for the _rtnval() call below.
                v=minn
                ok=self._re_ok(Tainted.bad_num,what,10)
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                # Exactly byteslen bytes of hex; placeholder is all zeroes.
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                # Printable ASCII excluding quotes and backslash.
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
234
class ArgActionLambda(argparse.Action):
        """argparse action that delegates to an arbitrary callable.

        The `fn` keyword given to add_argument() is invoked as
        fn(values, namespace, parser, option_string) each time the
        option is encountered.
        """
        def __init__(self, fn, **kwargs):
                self.fn = fn
                super(ArgActionLambda, self).__init__(**kwargs)
        def __call__(self,ap,ns,values,option_string):
                handler = self.fn
                handler(values, ns, ap, option_string)
241
class PkmBase():
        """Base class for the --pubkeys-* public-key output modes.

        Presumably driven per site as: site_start(), then site_serial()
        and write_key() as needed, then site_finish() (TODO confirm
        against the caller).  The base versions of the last three do
        nothing.
        """
        def site_start(self,pubkeys_path):
                """Begin a site; pubkeys_path is its path in the key directory."""
                self._pa=pubkeys_path
                self._fs=FilterState()
        def site_serial(self,serial):
                """Record the site's key serial (ignored here)."""
        def write_key(self,k):
                """Handle one key object (ignored here)."""
        def site_finish(self,confw):
                """Finish the site, writing to confw (nothing here)."""
249
class PkmSingle(PkmBase):
        """Key mode 'single': write exactly one key per site to sites.conf.

        Keys are collected while the site is read; at the end there must
        be exactly one usable key, which is written as a `key` line.
        Zero or several usable keys are reported via complain().
        """
        opt = 'single'
        help = 'write one public key per site to sites.conf'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._outk = []
        def write_key(self,k):
                # Only keys usable as a site's sole key for the selected
                # output version (and current filter state) are kept.
                if k.okforonlykey(output_version,self._fs):
                        self._outk.append(k)
        def site_finish(self,confw):
                if not self._outk:
                        complain("site with no public key")
                elif len(self._outk) > 1:
                        debugrepr('outk ', self._outk)
                        complain(
 "site with multiple public keys, without --pubkeys-install (maybe --output-version=1 would help)"
                        )
                else:
                        confw.write("key %s;\n"%str(self._outk[0]))
269
class PkmInstall(PkmBase):
        """Key mode 'install': write each site's keys into the key directory.

        Keys (and the serial) go to "<pubkeys_path>~tmp", which is renamed
        to "<pubkeys_path>~update" once the site is complete.  sites.conf
        then only refers to the key path via a peer-keys directive.
        """
        opt = 'install'
        help = 'install public keys in public key directory'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._pw=open(self._pa+'~tmp','w')
        def site_serial(self,serial):
                self._pw.write('serial %s\n' % serial)
        def write_key(self,k):
                wout=k.forpub(output_version,self._fs)
                self._pw.write(' '.join(wout))
                self._pw.write('\n')
        def site_finish(self,confw):
                self._pw.close()
                os.rename(self._pa+'~tmp',self._pa+'~update')
                # Same peer-keys line as PkmElide emits.  Written directly
                # rather than via PkmElide.site_finish(self, ...): calling
                # another class's method with a PkmInstall as self is an
                # unbound-method TypeError on Python 2.
                confw.write("peer-keys \"%s\";\n"%self._pa)
286
class PkmElide(PkmBase):
        """Key mode 'elide': emit keys nowhere; sites.conf just names the path.

        Only a peer-keys directive pointing at the site's key path is
        written to sites.conf.
        """
        opt = 'elide'
        help = 'no public keys in sites.conf output nor in directory'
        def site_finish(self,confw):
                directive = 'peer-keys "%s";\n' % self._pa
                confw.write(directive)
292
class OpBase():
        """Base operating mode: sites file in, sites file out.

        Reads a sites file from self.inputfile (stdin if None) and writes
        one to self.outputfile (stdout if None).  Subclasses provide
        write_out_heading(f) and write_out_contents(f).
        """
        def check_group(self,group,w):
                # The group word (w[2]) is optional; validate it only if
                # present.
                if len(w) >= 3:
                        w[2].groupname()
        def positional_args(self, av):
                # Positional arguments: [inputfile [outputfile]].
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (self.inputfile, self.outputfile) = (av.arg + [None]*2)[0:2]
        def read_in(self):
                if self.inputfile is None:
                        self.inputlines = pfile("stdin",sys.stdin.readlines())
                else:
                        self.inputlines = pfilepath(self.inputfile)
        def write_out(self):
                if self.outputfile is None:
                        self._write_sites(sys.stdout)
                        return
                # Write under a temporary name and rename into place only
                # once the file is complete and closed.
                tmpname=self.outputfile+"-tmp"
                with open(tmpname,'w') as f:
                        self._write_sites(f)
                os.rename(tmpname,self.outputfile)
        def _write_sites(self,f):
                # Emit the whole sites file: banner, mode-specific heading
                # and contents, trailer.
                f.write("# sites file autogenerated by make-secnet-sites\n")
                self.write_out_heading(f)
                f.write("# use make-secnet-sites to turn this file into a\n")
                f.write("# valid /etc/secnet/sites.conf file\n\n")
                self.write_out_contents(f)
                f.write("# end of sites file\n")
323
class OpConf(OpBase):
        """sites.conf generation mode (the default)."""
        opts = ['--conf']
        help = 'sites.conf generation mode (default)'
        def write_out(self):
                if self.outputfile is None:
                        outputsites(sys.stdout)
                        return
                tmp_outputfile=self.outputfile+'~tmp~'
                # Close (and so flush) the temporary file before renaming it
                # into place; otherwise the final file can be observed
                # incomplete while buffered output is still pending.
                with open(tmp_outputfile,'w') as of:
                        outputsites(of)
                os.rename(tmp_outputfile,self.outputfile)
336
class OpFilter(OpBase):
        """Sites-file filtering mode: write self.inputlines back out to stdout."""
        opts = ['--filter']
        help = 'sites file filtering mode'
        def positional_args(self, av):
                # Overrides OpBase.positional_args (previously misspelled,
                # so it never took effect).  Accepts an optional input
                # file; output always goes to stdout.
                if len(av.arg)>1:
                        print("Too many arguments")
                        sys.exit(1)
                (self.inputfile,) = (av.arg + [None])[0:1]
                self.outputfile = None
        def write_out_heading(self,f):
                f.write("# --filter --output-version=%d\n"%output_version)
        def write_out_contents(self,f):
                for i in self.inputlines: f.write(i)
349
class OpUserv(OpBase):
        """userv service mode: replace one group's fragment, rebuild sites.

        Positional arguments: header file, group-file directory, output
        file and the (caller-supplied, untrusted) group name.  The lines
        read from stdin are stored as the group's fragment, and the
        output file is regenerated from the header followed by every
        stored R* fragment.
        """
        opts = ['--userv','-u']
        help = 'userv service fragment update mode'
        def positional_args(self, av):
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (self.header, self.groupfiledir,
                 self.outputfile, self.group) = av.arg
                # untrusted argument from caller
                self.group = Tainted(self.group,0,'command line')
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                self.user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                ugs=os.environ["USERV_GROUP"]
                if not any(self.group==g for g in ugs.split()):
                        print("caller not in group %s"%self.group.groupname())
                        sys.exit(1)
        def check_group(self,group,w):
                # Callers may only submit data for their own group.
                if group!=self.group: complain("Incorrect group!")
                OpBase.check_group(self,group,w)
        def read_in(self):
                self.headerinput=pfilepath(self.header,allow_include=True)
                self.userinput=sys.stdin.readlines()
                # Result discarded (presumably parsed for its checks); the
                # raw lines are what get stored in write_out().
                pfile("user input",self.userinput)
        def write_out(self):
                # Put the user's input into their group file (written as
                # T<group>, then renamed to R<group>), and rebuild the
                # main sites file.
                groupname=self.group.groupname()
                tmpname=self.groupfiledir+"/T"+groupname
                with open(tmpname,'w') as f:
                        f.write("# Section submitted by user %s, %s\n"%
                                (self.user,time.asctime(time.localtime(time.time()))))
                        f.write("# Checked by make-secnet-sites version %s\n\n"
                                %VERSION)
                        for i in self.userinput: f.write(i)
                        f.write("\n")
                os.rename(tmpname,self.groupfiledir+"/R"+groupname)
                OpBase.write_out(self)
        def write_out_heading(self,f):
                f.write("# generated %s, invoked by %s\n"%
                        (time.asctime(time.localtime(time.time())),
                         self.user))
        def write_out_contents(self,f):
                for i in self.headerinput: f.write(i)
                # os.listdir() returns entries in arbitrary order; sort so
                # the generated sites file is reproducible.
                for name in sorted(os.listdir(self.groupfiledir)):
                        if name.startswith('R'):
                                with open(self.groupfiledir+"/"+name) as j:
                                        f.write(j.read())
409
def parse_args():
        """Parse the command line into the module-level option globals.

        Sets opmode (an instance of the selected Op* class), prefix,
        key_prefix, debug_level, output_version, pubkeys_dir and
        pubkeys_mode (a Pkm* class), then hands the remaining positional
        arguments to opmode.positional_args().
        """
        global opmode, prefix, key_prefix, debug_level
        global output_version, pubkeys_dir, pubkeys_mode

        parser = argparse.ArgumentParser(description='process secnet sites files')

        # These helpers take the class as a parameter, so each lambda
        # captures its own class (a bare loop variable would late-bind).
        def add_opmode(mode):
                parser.add_argument(*mode.opts, action=ArgActionLambda,
                        nargs=0,
                        fn=(lambda v,ns,*x: setattr(ns,'opmode',mode)),
                        help=mode.help)
        def add_pkm(mode):
                parser.add_argument('--pubkeys-'+mode.opt,
                        action=ArgActionLambda,
                        nargs=0,
                        fn=(lambda v,ns,*x: setattr(ns,'pkm',mode)),
                        help=mode.help)

        for mode in (OpConf, OpFilter, OpUserv):
                add_opmode(mode)
        parser.add_argument('--conf-key-prefix', action=ActionNoYes,
                        default=True,
                 help='prefix conf file key names derived from sites data')
        for mode in (PkmInstall, PkmSingle, PkmElide):
                add_pkm(mode)
        parser.add_argument('--pubkeys-dir',  nargs=1,
                        help='public key directory',
                        default=['/var/lib/secnet/pubkeys'])
        parser.add_argument('--output-version', nargs=1, type=int,
                        help='sites file output version',
                        default=[max_version])
        parser.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        parser.add_argument('--debug', '-D', action='count', default=0)
        parser.add_argument('arg',nargs=argparse.REMAINDER)

        av = parser.parse_args()
        debug_level = av.debug
        debugrepr('av',av)
        opmode = getattr(av,'opmode',OpConf)()
        prefix = av.prefix[0] if av.prefix is not None else ''
        key_prefix = av.conf_key_prefix
        output_version = av.output_version[0]
        pubkeys_dir = av.pubkeys_dir[0]
        pubkeys_mode = getattr(av,'pkm',PkmSingle)
        opmode.positional_args(av)

# Parse the command line at load time; code below reads the globals set
# here (e.g. output_version, key_prefix, prefix).
parse_args()
461
462 # Classes describing possible datatypes in the configuration file
463
class basetype:
        "Common protocol for configuration types."
        def add(self,obj,w):
                # Presumably reached when a single-valued property appears
                # again on the same object (list types override this).
                prop=w[0].raw()
                complain("%s %s already has property %s defined"%
                        (obj.type,obj.name,prop))
        def forsites(self,version,copy,fs):
                # By default a property's text is copied through unchanged.
                return copy
471
class conflist:
        "A list of some kind of configuration type."
        def __init__(self,subtype,w):
                self.subtype=subtype
                first=subtype(w)
                self.list=[first]
        def add(self,obj,w):
                # Repeated occurrences accumulate instead of clashing.
                item=self.subtype(w)
                self.list.append(item)
        def __str__(self):
                return ', '.join(str(item) for item in self.list)
        def forsites(self,version,copy,fs):
                # Only the most recently added entry decides the output.
                return self.list[-1].forsites(version,copy,fs)
def listof(subtype):
        """Return a constructor that builds a conflist of `subtype` from w."""
        def make(w):
                return conflist(subtype, w)
        return make
486
class single_ipaddr (basetype):
        "An IP address"
        def __init__(self,w):
                # Syntax is checked by ipaddress, which raises on bad input.
                text=w[1].raw_mark_ok()
                self.addr=ipaddress.ip_address(text)
        def __str__(self):
                return '"%s"' % (self.addr,)
493
class networks (basetype):
        "A set of IP addresses specified as a list of networks"
        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for word in w[1:]:
                        # strict=True: a network with host bits set is rejected.
                        net=ipaddress.ip_network(word.raw_mark_ok(),strict=True)
                        self.set.append([net])
        def __str__(self):
                return ",".join('"%s"' % n for n in self.set.networks())
503
class dhgroup (basetype):
        "A Diffie-Hellman group"
        def __init__(self,w):
                # Modulus then generator, both hex bignums bounded by
                # max['dh_bits'].
                self.mod=w[1].bignum_16('dh','dh mod')
                self.gen=w[2].bignum_16('dh','dh gen')
        def __str__(self):
                return 'diffie-hellman("{0}","{1}")'.format(self.mod,self.gen)
511
class hash (basetype):
        "A choice of hash function"
        def __init__(self,w):
                name_word=w[1]
                self.ht=name_word.raw()
                if self.ht in ('md5','sha1'):
                        name_word.raw_mark_ok()
                else:
                        complain("unknown hash type %s"%(self.ht))
                        # Placeholder so the object can still be formatted.
                        self.ht=None
        def __str__(self):
                return '%s' % (self.ht,)
524
class email (basetype):
        "An email address"
        def __init__(self,w):
                # Checked against Tainted's email character set.
                self.addr=w[1].email()
        def __str__(self):
                return '<' + self.addr + '>'
531
class boolean (basetype):
        "A boolean"
        def __init__(self,w):
                v=w[1]
                # Only the first character is examined: T/t/Y/y/1 mean true,
                # F/f/N/n/0 mean false.
                if re.match('[TtYy1]',v.raw()):
                        self.b=True
                        v.raw_mark_ok()
                elif re.match('[FfNn0]',v.raw()):
                        self.b=False
                        v.raw_mark_ok()
                else:
                        complain("invalid boolean value")
                        # Leave a placeholder (as `hash` does) so formatting
                        # this property later cannot raise AttributeError.
                        self.b=False
        def __str__(self):
                return ['False','True'][self.b]
546
class num (basetype):
        "A decimal number"
        def __init__(self,w):
                # Accepted range is 0 .. 2**31-1 (0x7fffffff).
                self.n=w[1].number(0,2**31-1)
        def __str__(self):
                return str(self.n)
553
class serial (basetype):
        "Public key set serial: 4 bytes written as 8 hex digits"
        def __init__(self,w):
                self.i=w[1].hexid(4,'serial')
        def __str__(self):
                return self.i
        def forsites(self,version,copy,fs):
                # Serials only appear in version 2 and later sites files.
                return copy if version >= 2 else []
562
class address (basetype):
        "A DNS name and UDP port number"
        def __init__(self,w):
                self.adr=w[1].host()
                # UDP port numbers are 16-bit; 0 is not usable, so the
                # valid range is 1..65535 (65536 was previously accepted).
                self.port=w[2].number(1,65535,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
570
class inpub (basetype):
        "Base for entries whose sites-file form is generated by forpub()"
        def forsites(self,version,xcopy,fs):
                # The copied input text is ignored; output is regenerated.
                regenerated=self.forpub(version,fs)
                return regenerated
574
class pubkey (inpub):
        "Some kind of public key"
        def __init__(self,w):
                self.a=w[1].name('algname')
                self.d=w[2].base91()
        def __str__(self):
                return 'make-public("%s","%s")'%(self.a,self.d)
        def forpub(self,version,fs):
                # Generic 'pub' keys cannot be expressed before version 2.
                if version >= 2:
                        return ['pub', self.a, self.d]
                return []
        def okforonlykey(self,version,fs):
                return bool(self.forpub(version,fs))
587
class rsakey (pubkey):
        "An old-style RSA public key"
        def __init__(self,w):
                self.l=w[1].number(0,max['rsa_bits'],'rsa len')
                self.e=w[2].bignum_10('rsa','rsa e')
                self.n=w[3].bignum_10('rsa','rsa n')
                if len(w) >= 5:
                        # Optional fifth word: validated as an email
                        # address but not stored.
                        w[4].email()
                self.a='rsa1'
                # Re-encode as a new-style 'rsa1' key body so the
                # pubkey.forpub() path works for versions >= 2.
                plain='%d %s %s' % (self.l, self.e, self.n)
                self.d=base91s_encode(plain.encode('ascii'))
        def __str__(self):
                # rsa-public(...) keeps the output usable by old secnet
                # executables.
                return 'rsa-public("%s","%s")'%(self.e,self.n)
        def forpub(self,version,fs):
                if version >= 2:
                        return pubkey.forpub(self,version,fs)
                # Old format: only emitted while the current key group id
                # (fs.pkg, set by pubkey_group) is 00000000.
                if fs.pkg != '00000000':
                        return []
                return ['pubkey', str(self.l), self.e, self.n]
611
class rsakey_newfmt(rsakey):
        "An old-style RSA public key in new-style sites format"
        # This is its own class simply to have its own constructor.
        def __init__(self,w):
                self.a=w[1].name()
                assert(self.a == 'rsa1')
                self.d=w[2].base91()
                try:
                        words=(base91s_decode(self.d)
                               .decode('ascii')
                               .split(' '))
                except UnicodeDecodeError:
                        complain('rsa1 key in new format has bad base91')
                        # Leave harmless placeholders and stop here; going
                        # on to rsakey.__init__ without decoded words would
                        # raise UnboundLocalError.
                        self.l=0
                        self.e=''
                        self.n=''
                        return
                # Rebuild an old-style word list (the leading keyword is a
                # dummy) and let rsakey validate the fields.
                w_inner=list(map(Tainted, ['X-PUB-RSA1'] + words))
                rsakey.__init__(self,w_inner)
629
class pubkey_group(inpub):
        "Public key group introducer"
        # appears in the site's list of keys mixed in with the keys
        def __init__(self,w,fallback):
                self.i=w[1].hexid(4,'pkg-id')
                self.fallback=fallback
        def forpub(self,version,fs):
                # Side effect: record this group as current in fs.pkg, which
                # rsakey.forpub() consults for following keys.
                fs.pkg=self.i
                if version >= 2:
                        keyword='pkgf' if self.fallback else 'pkg'
                        return [keyword, self.i]
                return []
        def okforonlykey(self,version,fs):
                # Never a key itself, but fs.pkg must still be updated.
                self.forpub(version,fs)
                return False
643         
def somepubkey(w):
        """Build the key object for a key-bearing line; w[0] is its keyword."""
        kw=w[0]
        if kw=='pub':
                if w[1]=='rsa1':
                        return rsakey_newfmt(w)
                return pubkey(w)
        if kw=='pubkey':
                return rsakey(w)
        if kw=='pkg':
                return pubkey_group(w,False)
        if kw=='pkgf':
                return pubkey_group(w,True)
        assert(False)
658
659 # Possible properties of configuration nodes
# keyword -> (constructor, description[, 'pub']).  Built from an ordered
# list of pairs so the key order matches the listing below.
keywords=dict([
 ('contact',(email,"Contact address")),
 ('dh',(dhgroup,"Diffie-Hellman group")),
 ('hash',(hash,"Hash function")),
 ('key-lifetime',(num,"Maximum key lifetime (ms)")),
 ('setup-timeout',(num,"Key setup timeout (ms)")),
 ('setup-retries',(num,"Maximum key setup packet retries")),
 ('wait-time',(num,"Time to wait after unsuccessful key setup (ms)")),
 ('renegotiate-time',(num,"Time after key setup to begin renegotiation (ms)")),
 ('restrict-nets',(networks,"Allowable networks")),
 ('networks',(networks,"Claimed networks")),
 ('serial',(serial,"public key set serial")),
 ('pkg',(listof(somepubkey),"start of public key group",'pub')),
 ('pkgf',(listof(somepubkey),"start of fallback public key group",'pub')),
 ('pub',(listof(somepubkey),"new style public site key")),
 ('pubkey',(listof(somepubkey),"Old-style RSA public site key",'pub')),
 ('peer',(single_ipaddr,"Tunnel peer IP address")),
 ('address',(address,"External contact address and port")),
 ('mobile',(boolean,"Site is mobile")),
])
680
def sp(name,value):
        "Simply output a property - the default case"
        line='%s %s;\n' % (name, value)
        return line

# All levels support these properties.  Each entry maps to a formatter
# (name, value) -> text; contact and restrict-nets are only emitted as
# comments, everything else via sp().
global_properties=dict.fromkeys(
        ['contact','dh','hash','key-lifetime','setup-timeout',
         'setup-retries','wait-time','renegotiate-time','restrict-nets'],
        sp)
global_properties['contact']=(
        lambda name,value:"# Contact email address: %s\n"%(value))
global_properties['restrict-nets']=(
        lambda name,value:"# restrict-nets %s\n"%value)
697
class level:
        """A level in the configuration hierarchy.

        Subclasses set depth, type and the allowed/required property
        tables.  output_data() writes the level, its properties and its
        children (sorted by key) as a nested `name { ... };` block.
        """
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}
        def __init__(self,w):
                self.type=w[0].keyword()
                self.name=w[1].name()
                self.properties={}
                self.children={}
        def indent(self,w,t):
                # Write t spaces.  (The old fixed 17-space slice silently
                # capped indentation for deeper nesting.)
                w.write(" "*t)
        def prop_out(self,n):
                return self.allow_properties[n](n,str(self.properties[n]))
        def output_props(self,w,ind):
                # Properties whose formatter is None are not written here.
                for i in sorted(self.properties.keys()):
                        if self.allow_properties[i]:
                                self.indent(w,ind)
                                w.write("%s"%self.prop_out(i))
        def kname(self):
                # With key_prefix, names get their type's initial in
                # upper case (e.g. "S" for site) as a prefix.
                return ((self.type[0].upper() if key_prefix else '')
                        + self.name)
        def output_data(self,w,path):
                ind = 2*len(path)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.output_props(w,ind+2)
                if self.depth==1: w.write("\n")
                for k in sorted(self.children.keys()):
                        c=self.children[k]
                        c.output_data(w,path+(c,))
                self.indent(w,ind)
                w.write("};\n")
732
class vpnlevel(level):
        "VPN level in the configuration hierarchy"
        depth=1
        leaf=0
        type="vpn"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"VPN admin contact address"
        }
        def __init__(self,w):
                level.__init__(self,w)
        def output_vpnflat(self,w,path):
                "Output flattened list of site names for this VPN"
                ind=2*(len(path)+1)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                inner_path=path+(self,)
                for child in self.children.values():
                        child.output_vpnflat(w,inner_path)
                w.write("\n")
                self.indent(w,ind+2)
                names=[child.kname() for child in self.children.values()]
                w.write("all-sites %s;\n"%','.join(names))
                self.indent(w,ind)
                w.write("};\n")
758
class locationlevel(level):
        "Location level in the configuration hierarchy"
        depth=2
        leaf=0
        type="location"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"Location admin contact address",
        }
        def __init__(self,w):
                level.__init__(self,w)
                # Group name taken from the third word of the "location" line.
                # pline() later passes it to opmode.check_group() when the
                # same location is mentioned again.
                # NOTE(review): raises IndexError when the line has no third
                # word - confirm callers guarantee one is present.
                self.group=w[2].groupname()
        def output_vpnflat(self,w,path):
                """Write this location's line of the flattened per-VPN section.

                The line maps the location's key name to a comma-separated
                list of vpn-data references, one per site, of the form
                <prefix>vpn-data/<vpn>/<location>/<site> (all key names).
                path is the tuple of ancestor nodes (the VPN).
                """
                ind=2*(len(path)+1)
                self.indent(w,ind)
                # Python 3 closures capture path/self directly, so the old
                # default-argument lambda workaround is no longer needed.
                site_refs=['/'.join([prefix+"vpn-data"]
                                    + [node.kname()
                                       for node in path+(self,site)])
                           for site in self.children.values()]
                w.write("%s %s;\n"%(self.kname(),','.join(site_refs)))
784
class sitelevel(level):
        """Node for a "site" line: a leaf (depth 3) of the hierarchy.

        Sites carry the addressing and key properties listed in
        allow_properties below.  They are written out as complete stanzas
        including a "link netlink" sub-stanza.
        """
        type="site"
        depth=3
        leaf=1
        allow_properties=global_properties.copy()
        # Entries mapped to None have no formatter, so output_props() skips them.
        allow_properties.update({
         'address':sp,
         'networks':None,
         'peer':None,
         'serial':None,
         'pkg':None,
         'pkgf':None,
         'pub':None,
         'pubkey':None,
         'mobile':sp,
        })
        require_properties={
         'dh':"Diffie-Hellman group",
         'contact':"Site admin contact address",
         'networks':"Networks claimed by the site",
         'hash':"hash function",
         'peer':"Gateway address of the site",
        }
        def __init__(self,w):
                level.__init__(self,w)
        def mangle_name(self):
                "Return the site name with every '/' replaced by ','"
                return ','.join(self.name.split('/'))
        def pubkeys_path(self):
                "Return this site's path under pubkeys_dir (passed to pubkeys_mode().site_start())"
                return pubkeys_dir + '/peer.' + self.mangle_name()
        def output_data(self,w,path):
                """Write this site's stanza in the vpn-data section to w.

                The stanza contains:
                  - the full '/'-joined name of the nodes in path;
                  - the site's key material, produced through pubkeys_mode();
                  - the formatted properties;
                  - a "link netlink" sub-stanza with routes and ptp-address.
                """
                ind=2*len(path)
                def emit(extra,text):
                        # indent by ind+extra, then write text
                        self.indent(w,ind+extra)
                        w.write(text)
                full_name='/'.join([node.name for node in path])
                emit(0,"%s {\n"%(self.kname()))
                emit(2,"name \"%s\";\n"%(full_name,))
                # Indentation only: pkm.site_finish(w) below supplies the rest
                # of this line (presumably the key configuration - the
                # pubkeys_mode API is not visible here).
                self.indent(w,ind+2)

                pkm = pubkeys_mode()
                debugrepr('pkm ',pkm)
                pkm.site_start(self.pubkeys_path())
                if 'serial' in self.properties:
                        pkm.site_serial(self.properties['serial'])
                # NOTE(review): 'pub' is not in require_properties, so a site
                # with no public key raises KeyError here - confirm this is
                # prevented elsewhere.
                for pubkey in self.properties["pub"].list:
                        debugrepr('pubkeys ', pubkey)
                        pkm.write_key(pubkey)
                pkm.site_finish(w)

                self.output_props(w,ind+2)
                emit(2,"link netlink {\n")
                emit(4,"routes %s;\n"%str(self.properties["networks"]))
                emit(4,"ptp-address %s;\n"%str(self.properties["peer"]))
                emit(2,"};\n")
                emit(0,"};\n")
847
# Levels in the configuration file: maps each level keyword to the class
# implementing that level of the hierarchy (depths 1, 2 and 3 respectively).
levels={'vpn':vpnlevel, 'location':locationlevel, 'site':sitelevel}
851
def complain(msg):
        "Report msg via moan(), prefixed with the current input file and line number"
        where="%s line %d: "%(file,line)
        moan(where+msg)
def moan(msg):
        """Report a general problem by printing msg and counting it.

        Once the global complaints counter has been set to None, any
        further problem is fatal and exits with status 1.
        """
        global complaints
        print(msg)
        if complaints is not None:
                complaints+=1
                return
        sys.exit(1)
861
class UntaintedRoot:
        """Trusted word used to build the root node.

        It wraps a fixed string and returns it from both name() and
        keyword(), presumably mirroring the accessors level() uses on
        parsed input words.
        """
        def __init__(self,s):
                self._s=s
        def name(self):
                "Return the wrapped string"
                return self._s
        def keyword(self):
                "Return the wrapped string"
                return self._s
866
# Root of the configuration tree; every vpn node is a child of it.
root=level([UntaintedRoot('root'), UntaintedRoot('root')])
# Stack of currently open nodes, outermost (root) first
obstack=[root]
allow_defs=0   # Level above which new definitions are permitted
871
def set_property(obj,w):
        """Set a property on a configuration node and return the property object.

        w is the list of words on the input line.  w[0] names the property.
        Its keywords entry supplies the constructor in slot 0 and, for
        aliases, the canonical name in slot 2.  A property that is already
        present on obj is extended with add(); otherwise a new one is built
        from w.
        """
        propname=w[0].raw_mark_ok()
        kwinfo=keywords[propname]
        if len(kwinfo) >= 3:
                propname=kwinfo[2] # alias: store under the canonical name
        props=obj.properties
        if propname not in props:
                props[propname]=kwinfo[0](w)
        else:
                props[propname].add(obj,w)
        return props[propname]
883
class FilterState:
        """Per-file state handed to property filters (prop.forsites).

        pfile() creates one for each file it reads, and pline() calls
        reset() every time a vpn/location/site node is entered.
        """
        def reset(self):
                "Return to the initial state, as at the start of each node."
                self.pkg = '00000000'
        def __init__(self):
                self.reset()
891
def pline(il,filterstate,allow_include=False):
        """Process one configuration-file line.

        il is the raw text of the line.  filterstate is the FilterState of
        the file being read.  allow_include says whether an "include"
        directive is permitted here.

        Side effects: updates the global node stack (obstack) and the
        definition permission level (allow_defs), creates or re-enters
        vpn/location/site nodes, and records properties on the current
        node.  Problems are reported with complain().

        Returns a list of output text fragments that reproduce the accepted
        line, or [] when the line is rejected.
        """
        global allow_defs, obstack, root, file, line
        w=il.rstrip('\n').split()
        if len(w)==0: return ['']
        w=list([Tainted(x) for x in w])
        keyword=w[0]
        current=obstack[len(obstack)-1]
        copyout_core=lambda: ' '.join([ww.output() for ww in w])
        indent='    '*len(obstack)
        copyout=lambda: [indent + copyout_core() + '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                # pfile() overwrites the global file/line position for the
                # included file.  Restore ours afterwards; otherwise later
                # complaints, line counting and relative includes in this
                # file would be attributed to the included file.
                saved_position=(file,line)
                try:
                        return pfilepath(newfile,allow_include=allow_include)
                finally:
                        file,line=saved_position
        if keyword.raw() in levels:
                if len(w) < 2:
                        # w[1] (the node name) is needed below; without this
                        # check the line would crash with IndexError.
                        complain("%s requires a name"%(keyword.raw()))
                        return []
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                if newdepth>currentdepth:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[len(obstack)-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if current.depth==2:
                                opmode.check_group(current.group, w)
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                filterstate.reset()
                obstack.append(current)
                return copyout()
        # Anything else must be a property of the current node.
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        elif current.depth == vpnlevel.depth < allow_defs:
                # VPN-level property after end-definitions raised allow_defs
                complain("Not allowed to set VPN properties here")
                return []
        else:
                prop=set_property(current,w)
                out=[copyout_core()]
                out=prop.forsites(output_version,out,filterstate)
                # Nothing to emit for this output version: copy the line
                # out commented-out instead, as a single line fragment.
                if len(out)==0: return [indent + '#' + copyout_core() + '\n']
                return [indent + ' '.join(out) + '\n']
965
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process its lines with pfile().

        Returns pfile()'s list of output fragments.  The file is read and
        closed before its lines are processed.  As a result it is not held
        open across nested includes, and it is not leaked if processing
        raises (for example the sys.exit() in pline()).
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
971
def pfile(name,lines,allow_include=False):
        """Process the lines of one configuration file.

        name is recorded in the global file position.  complain() uses it
        in messages, and pline() uses it to resolve relative include paths.
        lines is a sequence of line strings; lines starting with '#' are
        skipped.  The global line counter is advanced for every line.

        Returns the concatenated output fragments produced by pline().
        """
        global file,line
        file=name
        line=0
        outlines=[]
        filterstate = FilterState()
        for text in lines:
                line+=1
                # startswith() rather than text[0]: an empty string is then
                # passed to pline() as a blank line instead of raising
                # IndexError.
                if text.startswith('#'): continue
                outlines += pline(text,filterstate,allow_include=allow_include)
        return outlines
984
def outputsites(w):
        """Write the generated secnet sites configuration to w.

        The output consists of:
          - a comment header (version, time, command line);
          - a "vpn-data" section holding each VPN's full tree;
          - a flattened "vpn" section;
          - a top-level "all-sites" list referring to each VPN's all-sites
            entry.
        Every top-level name carries the global prefix.
        """
        vpns=root.children.values()

        header=("# secnet sites file autogenerated by make-secnet-sites "
                "version %s\n"%VERSION)
        w.write(header)
        w.write("# %s\n"%time.asctime())
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in vpns:
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in vpns:
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        vpn_refs=",".join(["%svpn/%s/all-sites"%(prefix,vpn.kname())
                           for vpn in vpns])
        w.write(prefix+"all-sites %s;\n"%vpn_refs)
1008
# Position of the input currently being read.  pfile() updates these and
# complain() includes them in its messages.
file=None
line=0
# Number of problems reported so far.  It is set to None once checking is
# complete, so that moan() then exits immediately.
complaints=0
1012
1013 # Sanity check section
1014 # Delete nodes where leaf=0 that have no children
1015
def live(n):
        """Return 1 if the subtree rooted at node n contains a leaf node, else 0.

        Despite the int result this is a yes/no test, not a count.  It stops
        at the first leaf it finds.
        """
        if n.leaf: return 1
        return 1 if any(live(child) for child in n.children.values()) else 0
def delempty(n):
        """Recursively remove every child subtree of n that holds no leaf nodes.

        The children are snapshotted first because entries are deleted
        from n.children while iterating.
        """
        for childkey, child in list(n.children.items()):
                delempty(child)
                if not live(child):
                        del n.children[childkey]
1028
1029 # Check that all constraints are met (as far as I can tell
1030 # restrict-nets/networks/peer are the only special cases)
1031
def checkconstraints(n,p,ra):
        """Recursively check node n and everything below it.

        p holds the properties inherited from n's ancestors.  ra is the
        address set still permitted by their restrict-nets settings.

        Checks that:
          - every required property is present, directly or inherited;
          - no effective property is forbidden at this level;
          - the node's networks lie within the permitted addresses;
          - its peer lies within its networks.
        Each violation is reported with moan().
        """
        effective=p.copy()
        effective.update(n.properties)
        for prop in n.require_properties:
                if prop in effective:
                        continue
                moan("%s %s is missing property %s"%(n.type,n.name,prop))
        for prop in effective:
                if prop not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%(n.type,n.name,prop))
        # Check address range restrictions
        permitted=ra
        if "restrict-nets" in n.properties:
                permitted=ra.intersection(n.properties["restrict-nets"].set)
        if "networks" in n.properties:
                nets=n.properties["networks"].set
                if not nets <= permitted:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                if ("peer" in n.properties
                    and not nets.contains(n.properties["peer"].addr)):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,effective,permitted)
1057
# Read the input via the selected operating mode.  This presumably fills
# in the tree under root through pfile()/pline().
opmode.read_in()

# Prune branches that contain no sites, then validate what remains.
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

# Refuse to write any output if validation found problems.
if complaints>0:
        summary=("There was 1 problem." if complaints==1
                 else "There were %d problems."%(complaints))
        print(summary)
        sys.exit(1)
complaints=None # from now on moan() exits on any further complaint

opmode.write_out()