chiark / gitweb /
make-secnet-sites: Move sites file writing into OpBase
[secnet.git] / make-secnet-sites
1 #! /usr/bin/env python3
2 #
3 # This file is part of secnet.
4 # See README for full list of copyright holders.
5 #
6 # secnet is free software; you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by
8 # the Free Software Foundation; either version 3 of the License, or
9 # (at your option) any later version.
10
11 # secnet is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 # General Public License for more details.
15
16 # You should have received a copy of the GNU General Public License
17 # version 3 along with secnet; if not, see
18 # https://www.gnu.org/licenses/gpl.html.
19
20 """VPN sites file manipulation.
21
22 This program enables VPN site descriptions to be submitted for
23 inclusion in a central database, and allows the resulting database to
24 be turned into a secnet configuration file.
25
26 A database file can be turned into a secnet configuration file simply:
27 make-secnet-sites.py [infile [outfile]]
28
29 It would be wise to run secnet with the "--just-check-config" option
30 before installing the output on a live system.
31
32 The program expects to be invoked via userv to manage the database; it
33 relies on the USERV_USER and USERV_GROUP environment variables. The
34 command line arguments for this invocation are:
35
36 make-secnet-sites.py -u header-filename groupfiles-directory output-file \
37   group
38
39 All but the last argument are expected to be set by userv; the 'group'
40 argument is provided by the user. A suitable userv configuration file
41 fragment is:
42
43 reset
44 no-disconnect-hup
45 no-suppress-args
46 cd ~/secnet/sites-test/
47 execute ~/secnet/make-secnet-sites.py -u vpnheader groupfiles sites
48
49 This program is part of secnet.
50
51 """
52
53 from __future__ import print_function
54 from __future__ import unicode_literals
55 from builtins import int
56
57 import string
58 import time
59 import sys
60 import os
61 import getopt
62 import re
63 import argparse
64 import math
65
66 import ipaddress
67
68 # entry 0 is "near the executable", or maybe from PYTHONPATH=.,
69 # which we don't want to preempt
70 sys.path.insert(1,"/usr/local/share/secnet")
71 sys.path.insert(1,"/usr/share/secnet")
72 import ipaddrset
73 import base91
74
75 from argparseactionnoyes import ActionNoYes
76
# Version of this tool; written into the header of each group file
# by OpUserv.write_out.
VERSION="0.1.18"

# Highest sites-file output version; the default for --output-version.
max_version = 2

from sys import version_info
if version_info.major == 2:  # for python2
    # Python 2 only: wrap stdin/stdout in UTF-8 codecs and replace
    # open() with a UTF-8 text-mode wrapper around io.open.
    import codecs
    sys.stdin = codecs.getreader('utf-8')(sys.stdin)
    sys.stdout = codecs.getwriter('utf-8')(sys.stdout)
    import io
    open=lambda f,m='r': io.open(f,m,encoding='utf-8')

# Size limits consulted by the Tainted validators.  Note that this
# name shadows the builtin max() for the rest of the module.
max={'rsa_bits':8200,'name':33,'dh_bits':8200,'algname':127}
90
def debugrepr(*args):
        """Print repr() of the argument tuple to stderr when debugging.

        Nothing is printed unless the global debug_level is above zero.
        """
        if debug_level <= 0:
                return
        print(repr(args), file=sys.stderr)
94
def base91s_encode(bindata):
        """Encode bindata with base91, replacing every '"' by '-'."""
        encoded = base91.encode(bindata)
        return encoded.replace('"', "-")
97
def base91s_decode(string):
        """Undo base91s_encode: turn '-' back into '"', then base91-decode."""
        restored = string.replace("-", '"')
        return base91.decode(restored)
100
class Tainted:
        """An untrusted input word that must be validated before use.

        Each validator (name, number, host, ...) checks the raw string
        and returns either the accepted value or a harmless placeholder;
        problems are reported through complain().

        self._ok records the validation state:
          None  - not checked yet
          True  - checked and accepted
          False - checked and rejected
        """
        def __init__(self,s,tline=None,tfile=None):
                """Wrap the raw string s.

                tline and tfile give the position used in messages; when
                omitted the module globals line and file are used.
                """
                self._s=s
                self._ok=None
                self._line=line if tline is None else tline
                self._file=file if tfile is None else tfile
        def __eq__(self,e):
                return self._s==e
        def __ne__(self,e):
                # for Python2
                return not self.__eq__(e)
        def __str__(self):
                # Deliberate: unvalidated data must never be formatted.
                raise RuntimeError('direct use of Tainted value')
        def __repr__(self):
                return 'Tainted(%s)' % repr(self._s)

        def _bad(self,what,why):
                "Mark the value rejected, complain, and return False."
                assert(self._ok is not True)
                self._ok=False
                complain('bad parameter: %s: %s' % (what, why))
                return False

        def _max_ok(self,what,maxlen):
                "Return True unless the raw string is longer than maxlen."
                if len(self._s) > maxlen:
                        return self._bad(what,'too long (max %d)' % maxlen)
                return True

        def _re_ok(self,bad,what,maxlen=None):
                """Check length and syntax.

                bad is a compiled regex matching any forbidden content;
                maxlen defaults to the entry for what in the max table.
                Returns True if the value passed, False otherwise.
                """
                if maxlen is None: maxlen=max[what]
                self._max_ok(what,maxlen)
                if self._ok is False: return False
                if bad.search(self._s):
                        #print(repr(self), file=sys.stderr)
                        return self._bad(what,'bad syntax')
                return True

        def _rtnval(self, is_ok, ifgood, ifbad=''):
                "Record the verdict; return ifgood on success, else ifbad."
                if is_ok:
                        assert(self._ok is not False)
                        self._ok=True
                        return ifgood
                else:
                        assert(self._ok is not True)
                        self._ok=False
                        return ifbad

        def _rtn(self, is_ok, ifbad=''):
                "Like _rtnval, with the raw string as the success value."
                return self._rtnval(is_ok, self._s, ifbad)

        def raw(self):
                "Return the raw string without changing validation state."
                return self._s
        def raw_mark_ok(self):
                # caller promises to throw if syntax was dangerous
                return self._rtn(True)

        def output(self):
                """Return the text to copy to the output.

                Rejected values yield '', accepted values the raw string;
                a value that was never checked is a fatal error.
                """
                if self._ok is False: return ''
                if self._ok is True: return self._s
                print('%s:%d: unchecked/unknown additional data "%s"' %
                      (self._file,self._line,self._s),
                      file=sys.stderr)
                sys.exit(1)

        bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
        # secnet accepts _ at start of names, but we reserve that
        bad_name_counter=0
        def name(self,what='name'):
                "Validate an identifier; on failure return a unique dummy."
                ok=self._re_ok(Tainted.bad_name,what)
                return self._rtn(ok,
                                 '_line%d_%s' % (self._line, id(self)))

        def keyword(self):
                "Validate that the word is a known keyword or level name."
                ok=self._s in keywords or self._s in levels
                if not ok:
                        complain('unknown keyword %s' % self._s)
                return self._rtn(ok)

        bad_hex=re.compile(r'[^0-9a-fA-F]')
        def bignum_16(self,kind,what):
                "Validate a hex number of at most max[kind+'_bits'] bits."
                # floor division keeps the limit an integer on Python 3
                maxlen=(max[kind+'_bits']+3)//4
                ok=self._re_ok(Tainted.bad_hex,what,maxlen)
                return self._rtn(ok)

        bad_num=re.compile(r'[^0-9]')
        def bignum_10(self,kind,what):
                "Validate a decimal number, length-limited by max[kind+'_bits']."
                # NOTE(review): an n-bit number has about n*log10(2)
                # decimal digits, so dividing makes this limit much
                # laxer than the bit count suggests -- confirm intent.
                maxlen=math.ceil(max[kind+'_bits'] / math.log10(2))
                ok=self._re_ok(Tainted.bad_num,what,maxlen)
                return self._rtn(ok)

        def number(self,minn,maxx,what='number'):
                """Validate a small decimal integer in minn..maxx.

                Returns the integer, or minn if the value is rejected.
                """
                # not for bignums
                v=None  # stays None when the syntax check fails
                ok=self._re_ok(Tainted.bad_num,what,10)
                if ok and not self._s:
                        # bad_num cannot match an empty string, but
                        # int('') would raise ValueError
                        ok=self._bad(what,'bad syntax')
                if ok:
                        v=int(self._s)
                        if v<minn or v>maxx:
                                ok=self._bad(what,'out of range %d..%d'
                                             % (minn,maxx))
                return self._rtnval(ok,v,minn)

        def hexid(self,byteslen,what):
                "Validate exactly byteslen bytes of hex; zeros on failure."
                ok=self._re_ok(Tainted.bad_hex,what,byteslen*2)
                if ok:
                        if len(self._s) < byteslen*2:
                                ok=self._bad(what,'too short')
                return self._rtn(ok,ifbad='00'*byteslen)

        bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
        # We permit _ so we can refer to special non-host domains
        # which have A and AAAA RRs.  This is a crude check and we may
        # still produce config files with syntactically invalid
        # domains or addresses, but that is OK.
        def host(self):
                "Validate a host name or address (crude character check)."
                ok=self._re_ok(Tainted.bad_host,'host/address',255)
                return self._rtn(ok)

        bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
        # ^ This does not accept all valid email addresses.  That's
        # not really possible with this input syntax.  It accepts
        # all ones that don't require quoting anywhere in email
        # protocols (and also accepts some invalid ones).
        def email(self):
                "Validate an email address (restricted character set)."
                ok=self._re_ok(Tainted.bad_email,'email address',1023)
                return self._rtn(ok)

        bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
        def groupname(self):
                "Validate a group name."
                ok=self._re_ok(Tainted.bad_groupname,'group name',64)
                return self._rtn(ok)

        bad_base91=re.compile(r'[^!-~]|[\'\"\\]')
        def base91(self,what='base91'):
                "Validate a base91 string (printable, no quotes or backslash)."
                ok=self._re_ok(Tainted.bad_base91,what,4096)
                return self._rtn(ok)
234
class ArgActionLambda(argparse.Action):
        """argparse action that forwards each occurrence to a callback.

        The callback is supplied as the fn keyword argument and is
        invoked as fn(values, namespace, parser, option_string).
        """
        def __init__(self, fn, **kwargs):
                argparse.Action.__init__(self, **kwargs)
                self.fn = fn
        def __call__(self, ap, ns, values, option_string):
                callback = self.fn
                callback(values, ns, ap, option_string)
241
class PkmBase():
        """Base class for public-key output modes.

        site_start records the site's public-key path and a fresh
        FilterState; the remaining hooks do nothing by default.
        """
        def site_start(self,pubkeys_path):
                self._pa = pubkeys_path
                self._fs = FilterState()
        def site_serial(self,serial):
                return None
        def write_key(self,k):
                return None
        def site_finish(self,confw):
                return None
249
class PkmSingle(PkmBase):
        """Key mode that emits exactly one public key per site.

        Keys accepted by okforonlykey() are collected while the site is
        processed; site_finish writes the single key to the config, or
        complains if there are none or more than one.
        """
        opt = 'single'
        help = 'write one public key per site to sites.conf'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                self._outk = []
        def write_key(self,k):
                if k.okforonlykey(output_version,self._fs):
                        self._outk.append(k)
        def site_finish(self,confw):
                if not self._outk:
                        complain("site with no public key")
                elif len(self._outk) != 1:
                        debugrepr('outk ', self._outk)
                        # message previously lacked its closing parenthesis
                        complain(
 "site with multiple public keys, without --pubkeys-install (maybe --output-version=1 would help)"
                        )
                else:
                        confw.write("key %s;\n"%str(self._outk[0]))
269
class PkmInstall(PkmBase):
        """Key mode that writes each site's keys to a file in the key directory.

        Keys go to '<path>~tmp', which is renamed to '<path>~update'
        when the site is finished; the config then refers to the path.
        """
        opt = 'install'
        help = 'install public keys in public key directory'
        def site_start(self,pubkeys_path):
                PkmBase.site_start(self,pubkeys_path)
                tmp_path = self._pa+'~tmp'
                self._pw = open(tmp_path,'w')
        def site_serial(self,serial):
                self._pw.write('serial %s\n' % serial)
        def write_key(self,k):
                words = k.forpub(output_version,self._fs)
                self._pw.write(' '.join(words))
                self._pw.write('\n')
        def site_finish(self,confw):
                self._pw.close()
                tmp_path = self._pa+'~tmp'
                update_path = self._pa+'~update'
                os.rename(tmp_path,update_path)
                # NOTE(review): borrows PkmElide's method although this
                # class does not derive from it; this relies on Python 3
                # treating the method as a plain function.
                PkmElide.site_finish(self,confw)
286
class PkmElide(PkmBase):
        """Key mode that writes no keys, only a peer-keys path reference."""
        opt = 'elide'
        help = 'no public keys in sites.conf output nor in directory'
        def site_finish(self,confw):
                line_out = 'peer-keys "%s";\n' % self._pa
                confw.write(line_out)
292
class OpBase():
        """Common behaviour of the operating modes."""
        # Base case is reading a sites file from self.inputfile.
        # And writing a sites file to self.sitesfile.
        def read_in(self):
                """Parse the input: stdin if self.inputfile is None, else that file."""
                if self.inputfile is None:
                        pfile("stdin",sys.stdin.readlines())
                else:
                        pfilepath(self.inputfile)
        def write_out(self):
                """Write self.sitesfile via a temporary file and a rename.

                The body is produced by self.write_out_contents(f), which
                subclasses using this method must provide.
                """
                tmp_path=self.sitesfile+"-tmp"
                # 'with' closes the temp file even if writing fails;
                # the rename only happens after a complete write.
                with open(tmp_path,'w') as f:
                        f.write("# sites file autogenerated by make-secnet-sites\n")
                        f.write("# generated %s, invoked by %s\n"%
                                (time.asctime(time.localtime(time.time())),
                                 self.user))
                        f.write("# use make-secnet-sites to turn this file into a\n")
                        f.write("# valid /etc/secnet/sites.conf file\n\n")
                        self.write_out_contents(f)
                        f.write("# end of sites file\n")
                os.rename(tmp_path,self.sitesfile)
313
class OpConf(OpBase):
        """Default mode: turn a sites file into a secnet configuration.

        Usage is [infile [outfile]]; missing names mean stdin/stdout.
        """
        def is_service(self): return 0
        def positional_args(self, av):
                # Only infile and outfile are accepted.  The old limit
                # of 3 silently ignored a third argument.
                if len(av.arg)>2:
                        print("Too many arguments")
                        sys.exit(1)
                (self.inputfile, self.outputfile) = (av.arg + [None]*2)[0:2]
        def check_group(self,group,w): pass
        def write_out(self):
                """Write the configuration to stdout or to self.outputfile.

                A named output file is written under a temporary name and
                renamed into place once it is complete.
                """
                if self.outputfile is None:
                        outputsites(sys.stdout)
                        return
                tmp_outputfile=self.outputfile+'~tmp~'
                # Close (and so flush) the file before renaming it, so
                # the final name never refers to a partial file.
                with open(tmp_outputfile,'w') as of:
                        outputsites(of)
                os.rename(tmp_outputfile,self.outputfile)
331
class OpUserv(OpBase):
        """userv service mode: accept one group's submission and rebuild.

        Arguments are header-filename, groupfiles-directory, output-file
        and group; the group is supplied by the (untrusted) caller.
        """
        opts = ['--userv','-u']
        help = 'userv service fragment update mode'
        def is_service(self): return 1
        def positional_args(self, av):
                """Store the arguments and check the caller's group membership.

                Exits with status 1 on a wrong argument count, a missing
                USERV_USER/USERV_GROUP variable, or if the requested
                group is not listed in USERV_GROUP.
                """
                if len(av.arg)!=4:
                        print("Wrong number of arguments")
                        sys.exit(1)
                (self.header, self.groupfiledir,
                 self.sitesfile, self.group) = av.arg
                self.group = Tainted(self.group,0,'command line')
                # untrusted argument from caller
                if "USERV_USER" not in os.environ:
                        print("Environment variable USERV_USER not found")
                        sys.exit(1)
                self.user=os.environ["USERV_USER"]
                # Check that group is in USERV_GROUP
                if "USERV_GROUP" not in os.environ:
                        print("Environment variable USERV_GROUP not found")
                        sys.exit(1)
                ugs=os.environ["USERV_GROUP"]
                if not any(self.group==i for i in ugs.split()):
                        # self.group is Tainted and cannot be formatted
                        # directly; echo the caller's own raw argument.
                        print("caller not in group %s"%self.group.raw())
                        sys.exit(1)
        def check_group(self,group,w):
                if group!=self.group: complain("Incorrect group!")
                w[2].groupname()
        def read_in(self):
                """Parse the header file, then the user's submission from stdin."""
                self.headerinput=pfilepath(self.header,allow_include=True)
                self.userinput=sys.stdin.readlines()
                pfile("user input",self.userinput)
        def write_out(self):
                # Put the user's input into their group file, and
                # rebuild the main sites file
                gname=self.group.groupname()
                tmp_path=self.groupfiledir+"/T"+gname
                with open(tmp_path,'w') as f:
                        f.write("# Section submitted by user %s, %s\n"%
                                (self.user,time.asctime(time.localtime(time.time()))))
                        f.write("# Checked by make-secnet-sites version %s\n\n"
                                %VERSION)
                        for i in self.userinput: f.write(i)
                        f.write("\n")
                os.rename(tmp_path,self.groupfiledir+"/R"+gname)
                OpBase.write_out(self)
        def write_out_contents(self,f):
                """Write the header followed by every 'R*' group file."""
                for i in self.headerinput: f.write(i)
                for i in os.listdir(self.groupfiledir):
                        if i.startswith('R'):
                                with open(self.groupfiledir+"/"+i) as j:
                                        f.write(j.read())
388
def parse_args():
        """Parse the command line and publish the results as module globals.

        Sets opmode, service, prefix, key_prefix, debug_level,
        output_version, pubkeys_dir and pubkeys_mode, then passes the
        parsed arguments to the selected mode's positional_args().
        """
        global opmode
        global service
        global prefix
        global key_prefix
        global debug_level
        global output_version
        global pubkeys_dir
        global pubkeys_mode

        parser = argparse.ArgumentParser(description='process secnet sites files')

        def add_opmode(how):
                # Giving the option stores the mode class in the
                # namespace; it is instantiated after parsing.
                def choose(v,ns,*x):
                        setattr(ns,'opmode',how)
                parser.add_argument(*how().opts, action=ArgActionLambda,
                        nargs=0,
                        fn=choose,
                        help=how().help)
        add_opmode(OpUserv)

        parser.add_argument('--conf-key-prefix', action=ActionNoYes,
                        default=True,
                 help='prefix conf file key names derived from sites data')

        def add_pkm(how):
                # Same scheme for the public-key output mode classes.
                def choose(v,ns,*x):
                        setattr(ns,'pkm',how)
                parser.add_argument('--pubkeys-'+how().opt, action=ArgActionLambda,
                        nargs=0,
                        fn=choose,
                        help=how().help)
        for pkm_class in (PkmInstall, PkmSingle, PkmElide):
                add_pkm(pkm_class)

        parser.add_argument('--pubkeys-dir',  nargs=1,
                        help='public key directory',
                        default=['/var/lib/secnet/pubkeys'])
        parser.add_argument('--output-version', nargs=1, type=int,
                        help='sites file output version',
                        default=[max_version])
        parser.add_argument('--prefix', '-P', nargs=1,
                        help='set prefix')
        parser.add_argument('--debug', '-D', action='count', default=0)
        parser.add_argument('arg',nargs=argparse.REMAINDER)

        av = parser.parse_args()
        # debug_level must be set before the first debugrepr() call
        debug_level = av.debug
        debugrepr('av',av)

        mode_class = getattr(av,'opmode',OpConf)
        opmode = mode_class()
        service = opmode.is_service()
        if av.prefix is None:
                prefix = ''
        else:
                prefix = av.prefix[0]
        key_prefix = av.conf_key_prefix
        output_version = av.output_version[0]
        pubkeys_dir = av.pubkeys_dir[0]
        pubkeys_mode = getattr(av,'pkm',PkmSingle)
        opmode.positional_args(av)
438
# Parse the command line now, so that the globals parse_args() sets
# (opmode, prefix, key_prefix, ...) exist for the code below.
parse_args()
440
441 # Classes describing possible datatypes in the configuration file
442
class basetype:
        "Common protocol for configuration types."
        def add(self,obj,w):
                # A plain property may only be given once per object.
                details = (obj.type,obj.name,w[0].raw())
                complain("%s %s already has property %s defined" % details)
        def forsites(self,version,copy,fs):
                # By default the input words are passed through unchanged.
                return copy
450
class conflist:
        "A list of some kind of configuration type."
        def __init__(self,subtype,w):
                self.subtype = subtype
                first = subtype(w)
                self.list = [first]
        def add(self,obj,w):
                # Repeated properties append another element.
                item = self.subtype(w)
                self.list.append(item)
        def __str__(self):
                return ', '.join(str(item) for item in self.list)
        def forsites(self,version,copy,fs):
                # Delegate to the element added last.
                newest = self.list[-1]
                return newest.forsites(version,copy,fs)
def listof(subtype):
        "Return a constructor that builds a conflist of subtype from words w."
        def make(w):
                return conflist(subtype, w)
        return make
465
class single_ipaddr (basetype):
        "An IP address"
        def __init__(self,w):
                # ip_address() raises on bad syntax, which is why the
                # raw word can be marked as checked here.
                text = w[1].raw_mark_ok()
                self.addr = ipaddress.ip_address(text)
        def __str__(self):
                return '"%s"' % (self.addr,)
472
class networks (basetype):
        "A set of IP addresses specified as a list of networks"
        def __init__(self,w):
                self.set = ipaddrset.IPAddressSet()
                for word in w[1:]:
                        # strict=True is passed, so ip_network() is
                        # asked to reject networks with host bits set
                        net = ipaddress.ip_network(word.raw_mark_ok(),strict=True)
                        self.set.append([net])
        def __str__(self):
                quoted = ['"%s"' % (net,) for net in self.set.networks()]
                return ",".join(quoted)
482
class dhgroup (basetype):
        "A Diffie-Hellman group"
        def __init__(self,w):
                # both values are hex bignums
                self.mod = w[1].bignum_16('dh','dh mod')
                self.gen = w[2].bignum_16('dh','dh gen')
        def __str__(self):
                params = (self.mod,self.gen)
                return 'diffie-hellman("%s","%s")' % params
490
class hash (basetype):
        "A choice of hash function"
        def __init__(self,w):
                hname = w[1]
                self.ht = hname.raw()
                # only md5 and sha1 are recognised
                if self.ht=='md5' or self.ht=='sha1':
                        hname.raw_mark_ok()
                else:
                        complain("unknown hash type %s"%(self.ht))
                        self.ht = None
        def __str__(self):
                return '%s' % (self.ht,)
503
class email (basetype):
        "An email address"
        def __init__(self,w):
                address_word = w[1]
                self.addr = address_word.email()
        def __str__(self):
                # rendered in angle brackets
                return '<%s>' % (self.addr,)
510
class boolean (basetype):
        "A boolean"
        def __init__(self,w):
                v = w[1]
                text = v.raw()
                # only the first character decides the value
                if re.match('[TtYy1]',text):
                        self.b = True
                elif re.match('[FfNn0]',text):
                        self.b = False
                else:
                        complain("invalid boolean value")
                        return
                v.raw_mark_ok()
        def __str__(self):
                return 'True' if self.b else 'False'
525
class num (basetype):
        "A decimal number"
        def __init__(self,w):
                # limited to 0..0x7fffffff
                upper = 0x7fffffff
                self.n = w[1].number(0,upper)
        def __str__(self):
                return '%d' % (self.n,)
532
class serial (basetype):
        "A public key set serial: four bytes written as hex"
        def __init__(self,w):
                self.i = w[1].hexid(4,'serial')
        def __str__(self):
                return self.i
        def forsites(self,version,copy,fs):
                # omitted from sites output before version 2
                if version >= 2:
                        return copy
                return []
541
class address (basetype):
        "A DNS name and UDP port number"
        def __init__(self,w):
                self.adr=w[1].host()
                # UDP ports are 16-bit, so the largest valid port is
                # 65535 (the previous bound of 65536 was off by one).
                self.port=w[2].number(1,65535,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
549
class inpub (basetype):
        "Base for items in a site's key list; sites output is the forpub() form"
        def forsites(self,version,xcopy,fs):
                # the copied input words (xcopy) are ignored
                words = self.forpub(version,fs)
                return words
553
class pubkey (inpub):
        "Some kind of public key"
        def __init__(self,w):
                # algorithm name, then the base91 key data
                self.a = w[1].name('algname')
                self.d = w[2].base91()
        def __str__(self):
                return 'make-public("%s","%s")' % (self.a,self.d)
        def forpub(self,version,fs):
                # generic keys are not emitted before version 2
                if version >= 2:
                        return ['pub', self.a, self.d]
                return []
        def okforonlykey(self,version,fs):
                # usable as the only key iff it produces any output
                words = self.forpub(version,fs)
                return len(words) != 0
566
class rsakey (pubkey):
        "An old-style RSA public key"
        def __init__(self,w):
                self.l = w[1].number(0,max['rsa_bits'],'rsa len')
                self.e = w[2].bignum_10('rsa','rsa e')
                self.n = w[3].bignum_10('rsa','rsa n')
                # an optional fifth word is checked as an email address
                if len(w) > 4:
                        w[4].email()
                self.a = 'rsa1'
                e_bytes = self.e.encode('ascii')
                n_bytes = self.n.encode('ascii')
                packed = b'%d %s %s' % (self.l, e_bytes, n_bytes)
                self.d = base91s_encode(packed)
                # ^ this allows us to use the pubkey.forsites()
                # method for output in versions>=2
        def __str__(self):
                # this specialisation means we can generate files
                # compatible with old secnet executables
                return 'rsa-public("%s","%s")' % (self.e,self.n)
        def forpub(self,version,fs):
                if version >= 2:
                        return pubkey.forpub(self,version,fs)
                # Before version 2 only keys in group '00000000'
                # are emitted, in the old 'pubkey' form.
                if fs.pkg != '00000000':
                        return []
                return ['pubkey', str(self.l), self.e, self.n]
590
class rsakey_newfmt(rsakey):
        "An old-style RSA public key in new-style sites format"
        # This is its own class simply to have its own constructor.
        def __init__(self,w):
                """Decode 'pub rsa1 <base91>' and initialise as an rsakey.

                The base91 payload decodes to 'len e n'.  If it cannot
                be decoded, or has too few fields, complain and leave
                placeholder values instead of crashing.
                """
                self.a=w[1].name()
                assert(self.a == 'rsa1')
                self.d=w[2].base91()
                try:
                        fields=(base91s_decode(self.d)
                                .decode('ascii')
                                .split(' '))
                except UnicodeDecodeError:
                        fields=None
                        complain('rsa1 key in new format has bad base91')
                if fields is not None and len(fields) < 3:
                        fields=None
                        complain('rsa1 key in new format has too few fields')
                if fields is None:
                        # same placeholders the Tainted validators
                        # return for rejected values
                        self.l=0
                        self.e=''
                        self.n=''
                        return
                w_inner=list(map(Tainted, ['X-PUB-RSA1'] + fields))
                #print(repr(w_inner), file=sys.stderr)
                rsakey.__init__(self,w_inner)
608
class pubkey_group(inpub):
        "Public key group introducer"
        # appears in the site's list of keys mixed in with the keys
        def __init__(self,w,fallback):
                self.i = w[1].hexid(4,'pkg-id')
                self.fallback = fallback
        def forpub(self,version,fs):
                # Always record the current group in the filter state,
                # even when nothing is emitted.
                fs.pkg = self.i
                if version >= 2:
                        keyword = 'pkgf' if self.fallback else 'pkg'
                        return [keyword, self.i]
                return []
        def okforonlykey(self,version,fs):
                # Called only for its effect on fs; a group introducer
                # is never itself a key.
                self.forpub(version,fs)
                return False
622         
def somepubkey(w):
        "Build the key-list object matching the keyword in w[0]."
        #print(repr(w), file=sys.stderr)
        kind = w[0]
        if kind=='pubkey':
                return rsakey(w)
        if kind=='pub':
                # rsa1 keys in the new syntax get the RSA-aware class
                if w[1]=='rsa1':
                        return rsakey_newfmt(w)
                return pubkey(w)
        if kind=='pkg':
                return pubkey_group(w,False)
        if kind=='pkgf':
                return pubkey_group(w,True)
        assert(False)
637
# Possible properties of configuration nodes.
# Maps each keyword to a tuple (constructor, description); some
# entries carry a third element ('pub').
# NOTE(review): the third element looks like the property name the
# keyword is stored under -- confirm against the parser.
keywords={
 'contact':(email,"Contact address"),
 'dh':(dhgroup,"Diffie-Hellman group"),
 'hash':(hash,"Hash function"),
 'key-lifetime':(num,"Maximum key lifetime (ms)"),
 'setup-timeout':(num,"Key setup timeout (ms)"),
 'setup-retries':(num,"Maximum key setup packet retries"),
 'wait-time':(num,"Time to wait after unsuccessful key setup (ms)"),
 'renegotiate-time':(num,"Time after key setup to begin renegotiation (ms)"),
 'restrict-nets':(networks,"Allowable networks"),
 'networks':(networks,"Claimed networks"),
 'serial':(serial,"public key set serial"),
 'pkg':(listof(somepubkey),"start of public key group",'pub'),
 'pkgf':(listof(somepubkey),"start of fallback public key group",'pub'),
 'pub':(listof(somepubkey),"new style public site key"),
 'pubkey':(listof(somepubkey),"Old-style RSA public site key",'pub'),
 'peer':(single_ipaddr,"Tunnel peer IP address"),
 'address':(address,"External contact address and port"),
 'mobile':(boolean,"Site is mobile"),
}
659
def sp(name,value):
        "Simply output a property - the default case"
        pieces = (name, value)
        return "%s %s;\n" % pieces
663
# All levels support these properties.
# Each value is a formatter called as f(name, value) that returns the
# text to write for that property (see level.prop_out).
global_properties={
        'contact':(lambda name,value:"# Contact email address: %s\n"%(value)),
        'dh':sp,
        'hash':sp,
        'key-lifetime':sp,
        'setup-timeout':sp,
        'setup-retries':sp,
        'wait-time':sp,
        'renegotiate-time':sp,
        'restrict-nets':(lambda name,value:"# restrict-nets %s\n"%value),
}
676
class level:
        "A level in the configuration hierarchy"
        depth=0
        leaf=0
        allow_properties={}
        require_properties={}
        def __init__(self,w):
                self.type = w[0].keyword()
                self.name = w[1].name()
                self.properties = {}
                self.children = {}
        def indent(self,w,t):
                # Write t spaces, up to the length of the padding string.
                padding = "                 "
                w.write(padding[:t])
        def prop_out(self,n):
                # Format property n with the formatter registered for it.
                formatter = self.allow_properties[n]
                return formatter(n,str(self.properties[n]))
        def output_props(self,w,ind):
                # Properties whose formatter is falsy (e.g. None) are
                # skipped here.
                for prop in sorted(self.properties.keys()):
                        if not self.allow_properties[prop]:
                                continue
                        self.indent(w,ind)
                        w.write("%s"%self.prop_out(prop))
        def kname(self):
                # Name used in the output: optionally prefixed with the
                # upper-cased first letter of the level type.
                if key_prefix:
                        lead = self.type[0].upper()
                else:
                        lead = ''
                return lead + self.name
        def output_data(self,w,path):
                ind = 2*len(path)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.output_props(w,ind+2)
                if self.depth==1:
                        w.write("\n")
                # children are written in sorted key order
                for key in sorted(self.children.keys()):
                        child = self.children[key]
                        child.output_data(w,path+(child,))
                self.indent(w,ind)
                w.write("};\n")
711
class vpnlevel(level):
        "VPN level in the configuration hierarchy"
        depth=1
        leaf=0
        type="vpn"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"VPN admin contact address"
        }
        def __init__(self,w):
                level.__init__(self,w)
        def output_vpnflat(self,w,path):
                "Output flattened list of site names for this VPN"
                ind = 2*(len(path)+1)
                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                # each child writes its own flattened entry first
                for child in self.children.values():
                        child.output_vpnflat(w,path+(self,))
                w.write("\n")
                self.indent(w,ind+2)
                child_names = [child.kname()
                               for child in self.children.values()]
                w.write("all-sites %s;\n" % ','.join(child_names))
                self.indent(w,ind)
                w.write("};\n")
737
class locationlevel(level):
        "Location level in the configuration hierarchy"
        depth=2
        leaf=0
        type="location"
        allow_properties=global_properties.copy()
        require_properties={
         'contact':"Location admin contact address",
        }
        def __init__(self,w):
                level.__init__(self,w)
                self.group = w[2].groupname()
        def output_vpnflat(self,w,path):
                "Output this location's sites as a list of full key paths"
                ind = 2*(len(path)+1)
                self.indent(w,ind)
                # Each site becomes
                # <prefix>vpn-data/<ancestors...>/<location>/<site>,
                # built from the kname() of every node on the way.
                site_paths = []
                for site in self.children.values():
                        parts = [prefix+"vpn-data"]
                        for node in path+(self,site):
                                parts.append(node.kname())
                        site_paths.append('/'.join(parts))
                w.write("%s %s;\n"%(self.kname(),','.join(site_paths)))
763
class sitelevel(level):
        """Site level (i.e. a leafnode) in the configuration hierarchy"""
        depth=3
        leaf=1
        type="site"
        # Sites accept everything allowed globally plus these extras
        allow_properties=global_properties.copy()
        allow_properties.update(
                address=sp,
                networks=None,
                peer=None,
                serial=None,
                pkg=None,
                pkgf=None,
                pub=None,
                pubkey=None,
                mobile=sp,
        )
        require_properties={
                'dh':"Diffie-Hellman group",
                'contact':"Site admin contact address",
                'networks':"Networks claimed by the site",
                'hash':"hash function",
                'peer':"Gateway address of the site",
        }

        def __init__(self,w):
                super().__init__(w)

        def mangle_name(self):
                """Return the site name with every '/' replaced by ','."""
                return self.name.replace('/',',')

        def pubkeys_path(self):
                """Return this site's path under pubkeys_dir ('peer.' + mangled name)."""
                return pubkeys_dir + '/peer.' + self.mangle_name()

        def output_data(self,w,path):
                """Write this site's stanza to w.

                path is the tuple of nodes from the VPN down to this site;
                it sets the indentation and supplies the '/'-joined "name".
                The stanza holds the name, whatever the pubkeys_mode()
                object emits for the 'pub' keys (and 'serial', if set),
                the node's properties, and a netlink "link" section built
                from the 'networks' and 'peer' properties.
                """
                ind=2*len(path)
                inner=ind+2
                full_name='/'.join(node.name for node in path)

                self.indent(w,ind)
                w.write("%s {\n"%(self.kname()))
                self.indent(w,inner)
                w.write("name \"%s\";\n"%(full_name,))
                # Indentation for whatever site_finish() writes below
                self.indent(w,inner)

                pkm = pubkeys_mode()
                debugrepr('pkm ',pkm)
                pkm.site_start(self.pubkeys_path())
                if 'serial' in self.properties:
                        pkm.site_serial(self.properties['serial'])
                for key in self.properties["pub"].list:
                        debugrepr('pubkeys ', key)
                        pkm.write_key(key)
                pkm.site_finish(w)

                self.output_props(w,inner)
                self.indent(w,inner)
                w.write("link netlink {\n")
                link_lines=(
                        ("routes %s;\n","networks"),
                        ("ptp-address %s;\n","peer"),
                )
                for fmt,propname in link_lines:
                        self.indent(w,inner+2)
                        w.write(fmt%str(self.properties[propname]))
                self.indent(w,inner)
                w.write("};\n")
                self.indent(w,ind)
                w.write("};\n")
826
# Levels in the configuration file, keyed by the keyword that opens
# them; each value is the class implementing that level.
levels={
        'vpn':vpnlevel,
        'location':locationlevel,
        'site':sitelevel,
}
830
def complain(msg):
        """Complain about a particular input line.

        The message is prefixed with the current values of the module
        globals `file` and `line` and then handed to moan().
        """
        where="%s line %d: "%(file,line)
        moan(where+msg)
def moan(msg):
        """Complain about something in general.

        Prints msg and bumps the global complaint counter.  If the
        counter has been set to None, exits with status 1 instead.
        """
        global complaints
        print(msg)
        if complaints is None:
                sys.exit(1)
        complaints+=1
840
class UntaintedRoot:
        """Minimal word-like object whose name() and keyword() both
        return the string it was built from."""

        def __init__(self,s):
                self._s=s

        def name(self):
                return self._s

        def keyword(self):
                return self._s
845
# All vpns are children of this node
root=level([UntaintedRoot('root'),UntaintedRoot('root')])
# Nodes from the root down to the one currently being parsed
obstack=[root]
# Level above which new definitions are permitted
allow_defs=0
850
def set_property(obj,w):
        """Set a property on a configuration node.

        w is the list of words on the line; w[0] names the property and
        is looked up in `keywords`.  An existing property has the line
        added to it; otherwise a new one is built from the line.
        Returns the property object.
        """
        propname=w[0].raw_mark_ok()
        kw=keywords[propname]
        if len(kw)>=3:
                propname=kw[2] # for aliases
        props=obj.properties
        if propname in props:
                props[propname].add(obj,w)
        else:
                props[propname]=kw[0](w)
        return props[propname]
862
class FilterState:
        """Per-file filtering state, reset whenever a new node is entered."""

        def __init__(self):
                self.reset()

        def reset(self):
                """Return to the initial state.

                Called when we enter a new node, in particular at the
                start of each site.
                """
                self.pkg='00000000'
870
def pline(il,filterstate,allow_include=False):
        """Process a configuration file line.

        il is one raw input line, filterstate the FilterState of the file
        being read, and allow_include says whether an `include` directive
        is acceptable here.

        Returns a list of output text fragments for this line; the list
        is empty when the line was rejected with a complaint.  Entering a
        level or seeing `end-definitions` updates the module-level
        obstack and allow_defs.
        """
        global allow_defs, obstack, root
        w=il.rstrip('\n').split()
        if not w: return ['']
        w=[Tainted(x) for x in w]
        keyword=w[0]
        current=obstack[-1]
        # Both helpers are deliberately lazy: output() must only be
        # called once the words have been processed.  The indentation is
        # fixed now, before this line can change obstack.
        copyout_core=lambda: ' '.join([ww.output() for ww in w])
        indent='    '*len(obstack)
        copyout=lambda: [indent + copyout_core() + '\n']
        if keyword=='end-definitions':
                keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return copyout()
        if keyword=='include':
                if not allow_include:
                        complain("include not permitted here")
                        return []
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
                newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
                # ^ user of "include" is trusted so raw_mark_ok is good
                return pfilepath(newfile,allow_include=allow_include)
        if keyword.raw() in levels:
                if len(w)<2:
                        # w[1] is needed below; report the malformed line
                        # rather than dying with an IndexError.
                        complain("%s requires a name"%keyword.raw())
                        return []
                # We may go up any number of levels, but only down by one
                newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
                else:
                        complain("May not go from level %d to level %d"%
                                (currentdepth-1,newdepth))
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[-1]
                tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
                        if current.depth==2:
                                opmode.check_group(current.group, w)
                else:
                        # New
                        # Ignore depth check for now
                        nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
                                # we risk crashing if we continue
                                sys.exit(1)
                        current.children[tname]=nl
                        current=nl
                filterstate.reset()
                obstack.append(current)
                return copyout()
        # Anything else must be a property of the current node
        if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
                        (keyword.raw(),current.type))
                return []
        # Chained comparison: we are at VPN depth AND that depth is
        # below the level where definitions are still permitted.
        if current.depth == vpnlevel.depth < allow_defs:
                complain("Not allowed to set VPN properties here")
                return []
        prop=set_property(current,w)
        out=[copyout_core()]
        out=prop.forsites(output_version,out,filterstate)
        # Nothing to emit for this output version: keep the line as a comment
        if len(out)==0: return [indent + '#', copyout_core(), '\n']
        return [indent + ' '.join(out) + '\n']
944
def pfilepath(pathname,allow_include=False):
        """Read the file at pathname and process its lines with pfile().

        allow_include is passed through unchanged.  Returns pfile()'s
        list of output fragments.  The file is read completely and
        closed before processing starts, so the handle is released even
        if processing raises or exits, and nested includes do not pile
        up open files.
        """
        with open(pathname) as f:
                lines=f.readlines()
        return pfile(pathname,lines,allow_include=allow_include)
950
def pfile(name,lines,allow_include=False):
        """Process a file.

        name is used for error messages and as the base for relative
        includes; lines is the sequence of input lines.  Lines starting
        with '#' are skipped; every other line goes through pline() and
        the resulting fragments are concatenated and returned.

        The module globals `file` and `line` track the current position.
        """
        global file,line
        file=name
        line=0
        outlines=[]
        filterstate = FilterState()
        for lineno,text in enumerate(lines,1):
                # Re-establish our position on every line: an `include`
                # handled by pline() re-enters pfile(), which overwrites
                # these globals with the included file's name and count.
                file=name
                line=lineno
                # startswith() also copes with an empty string
                if text.startswith('#'): continue
                outlines += pline(text,filterstate,allow_include=allow_include)
        return outlines
963
def outputsites(w):
        """Output include file for secnet configuration.

        Writes to w a comment header (version, timestamp, command line),
        then the raw data of every VPN under the root, then the per-VPN
        flattened lists, then a single list naming every VPN's all-sites
        entry.
        """
        w.write("# secnet sites file autogenerated by make-secnet-sites "
                "version %s\n"%VERSION)
        w.write("# %s\n"%time.asctime(time.localtime(time.time())))
        w.write("# Command line: %s\n\n"%' '.join(sys.argv))

        vpns=root.children.values()

        # Raw VPN data section of file
        w.write(prefix+"vpn-data {\n")
        for vpn in vpns:
                vpn.output_data(w,(vpn,))
        w.write("};\n")

        # Per-VPN flattened lists
        w.write(prefix+"vpn {\n")
        for vpn in vpns:
                vpn.output_vpnflat(w,())
        w.write("};\n")

        # Flattened list of sites
        per_vpn=["%svpn/%s/all-sites"%(prefix,vpn.kname()) for vpn in vpns]
        w.write(prefix+"all-sites %s;\n"%",".join(per_vpn))
987
# Input position used by complain(); maintained by pfile()
file=None
line=0
# Number of problems reported through moan() so far
complaints=0
991
992 # Sanity check section
993 # Delete nodes where leaf=0 that have no children
994
def live(n):
        """Return 1 if n is a leafnode or has one anywhere below it, else 0."""
        if n.leaf:
                return 1
        if any(live(child) for child in n.children.values()):
                return 1
        return 0
def delempty(n):
        """Delete nodes that have no leafnode children.

        Works depth-first and edits n.children in place.
        """
        # Snapshot the items so entries can be deleted while we walk them
        for key,child in list(n.children.items()):
                delempty(child)
                if not live(child):
                        del n.children[key]
1007
1008 # Check that all constraints are met (as far as I can tell
1009 # restrict-nets/networks/peer are the only special cases)
1010
def checkconstraints(n,p,ra):
        """Check node n and everything below it, moaning about violations.

        p is the dict of properties inherited from the ancestors and ra
        the address set that the ancestors' restrict-nets still permit.
        Reports missing required properties, forbidden properties,
        networks outside the permitted range, and a peer address that is
        not inside the node's own networks.
        """
        props=n.properties
        new_p=p.copy()
        new_p.update(props)
        for required in n.require_properties:
                if required not in new_p:
                        moan("%s %s is missing property %s"%
                                (n.type,n.name,required))
        for present in new_p:
                if present not in n.allow_properties:
                        moan("%s %s has forbidden property %s"%
                                (n.type,n.name,present))
        # Check address range restrictions
        new_ra=ra
        if "restrict-nets" in props:
                new_ra=ra.intersection(props["restrict-nets"].set)
        if "networks" in props:
                nets=props["networks"].set
                if not nets <= new_ra:
                        moan("%s %s networks out of bounds"%(n.type,n.name))
                # The peer is only checked on nodes that declare networks
                if "peer" in props and not nets.contains(props["peer"].addr):
                        moan("%s %s peer not in networks"%(n.type,n.name))
        for child in n.children.values():
                checkconstraints(child,new_p,new_ra)
1036
opmode.read_in()

# Prune childless branches, then validate what is left
delempty(root)
checkconstraints(root,{},ipaddrset.complete_set())

if complaints>0:
        if complaints==1:
                summary="There was 1 problem."
        else:
                summary="There were %d problems."%(complaints)
        print(summary)
        sys.exit(1)
complaints=None # from here on, moan() exits at once instead of counting

opmode.write_out()