make-secnet-sites: Check input file word syntax
author Ian Jackson <ijackson@chiark.greenend.org.uk>
Thu, 24 Oct 2019 14:10:29 +0000 (15:10 +0100)
committer Ian Jackson <ijackson@chiark.greenend.org.uk>
Thu, 24 Oct 2019 18:16:17 +0000 (19:16 +0100)
make-secnet-sites sometimes reads untrusted input.  And we copy it to
various output files, including secnet configuration files which have
a different lexical syntax and are particularly vulnerable to a
syntax stuffing/inadequate escaping attack.
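
To make the risk concrete, here is a small sketch.  The format string
is the one used by dhgroup.__str__ in this file; the helper name
render_dhgroup and the hostile input word are invented for
illustration.  Input words are split on whitespace, so the hostile
word contains none.

```python
# Format string copied from dhgroup.__str__; everything else here is
# a hypothetical illustration, not code from make-secnet-sites.
def render_dhgroup(mod, gen):
    return 'diffie-hellman("%s","%s")' % (mod, gen)

# A well-behaved input word stays inside its quotes.
benign = render_dhgroup('abcdef', '2')

# A hostile word closes the quote and the parenthesis itself, then
# smuggles in extra text which ends up outside the intended string.
hostile = render_dhgroup('abcdef', '2");injected("x')

print(benign)
print(hostile)
```

Without a check on input (or quoting on output), the second result
contains text that the output syntax would treat as separate from the
intended string.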

In principle we could quote everything appropriately on output, but
actually we probably just want to check it, since the syntax of all
these directives and their parameters is quite restricted.

In order to ensure that we catch everything, and that if we missed a
location we get a crash rather than a security vulnerability, we take
the following approach:

Each untrusted input word is wrapped up in a new Tainted object.  The
Tainted object has a number of methods for checking and returning
values which are suitable for various purposes.  But attempts to
simply print it (eg to an output file) are made to fail.

The Tainted object keeps track internally of whether it has been
checked.  This is going to be important in a moment.
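
As a rough sketch of the idea (a simplified stand-in, not the Tainted
class in the diff below): the class name TaintedSketch and the
hex_digits checker are hypothetical, and the real class has more
checkers and reports errors via complain().

```python
import re

class TaintedSketch:
    """Simplified stand-in for the Tainted class; names are illustrative."""
    _bad_hex = re.compile(r'[^0-9a-fA-F]')

    def __init__(self, s):
        self._s = s
        self._ok = None   # None: unchecked, True: passed, False: failed

    def __str__(self):
        # Any attempt to format or print the wrapper fails loudly.
        raise RuntimeError('direct use of Tainted value')

    def hex_digits(self):
        # Hypothetical checker: hands back the plain string only if
        # every character is a hex digit, and records the outcome.
        good = not self._bad_hex.search(self._s)
        self._ok = good
        return self._s if good else ''

w = TaintedSketch('deadbeef')
try:
    formatted = '%s' % w      # calls __str__, which raises
    leaked = True
except RuntimeError:
    leaked = False

checked = w.hex_digits()      # plain str, safe to store and print
```

Here a forgotten call site crashes in __str__ rather than silently
copying the unchecked word into an output file.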

Naive call sites use straightforward methods on w[N] to get checked
values for storage in their own data structures.

Knowledgeable use sites may call .raw() to get the unchecked value,
and .raw_mark_ok() if they know that the value is good (or are about
to do something which will definitely crash if not, so that a bad
value cannot escape).

Obviously storing the results of .raw() in a call site's data
structure would escape the taint checking.  So we don't do that unless
we have done the check ourselves.
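
The two patterns can be sketched as follows.  Word is a hypothetical
cut-down wrapper with only the two raw accessors; parse_addr and
parse_hash are illustrative names loosely modelled on the
single_ipaddr and hash classes in the diff.

```python
import ipaddress

class Word:
    """Hypothetical cut-down wrapper with only the two raw accessors."""
    def __init__(self, s):
        self._s = s
        self._ok = None
    def raw(self):
        return self._s        # unchecked; does not change the state
    def raw_mark_ok(self):
        self._ok = True       # the caller vouches for the value
        return self._s

def parse_addr(word):
    # ip_address() raises ValueError for anything that is not an
    # address, so a bad value cannot escape into our data structures.
    return ipaddress.ip_address(word.raw_mark_ok())

def parse_hash(word):
    # Compare via raw() first; mark OK only once the value is known
    # to be one of the permitted strings.
    if word.raw() not in ('md5', 'sha1'):
        return None
    return word.raw_mark_ok()

addr = parse_addr(Word('192.0.2.1'))
try:
    parse_addr(Word('192.0.2.1";evil'))
    rejected = False
except ValueError:
    rejected = True
```

In both functions the unchecked string is stored only after something
(the ipaddress parser, or an explicit comparison) has vetted it.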

Within the Tainted implementation we really wanted an error monad.  Using
python exceptions for this looked like it was going to be too
abstruse.  So we open-code the monad with a conventional `ok' local
variable.  Each entrypoint returns using ._rtn() which can
double-check that no error has been lost.
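
A miniature of this `ok' pattern might look like the following.  This
is a hypothetical sketch, not the commit's code: the class name
Checked and the errors list are invented, and unlike the diff, _bad
here returns False so that its result can be assigned straight to ok.

```python
import re

class Checked:
    """Hypothetical miniature of the ok-flag pattern."""
    bad_num = re.compile(r'[^0-9]')

    def __init__(self, s):
        self._s = s
        self._ok = None
        self.errors = []

    def _bad(self, why):
        # Record a failure; a value already marked good must not fail.
        assert self._ok is not True
        self._ok = False
        self.errors.append(why)
        return False

    def _rtn(self, ok, ifgood, ifbad):
        if ok:
            # If an earlier step recorded a failure but the caller's
            # ok flag is still true, an error has been lost.
            assert self._ok is not False
            self._ok = True
            return ifgood
        assert self._ok is not True
        self._ok = False
        return ifbad

    def number(self, lo, hi):
        ok = True
        v = lo
        if not self._s or len(self._s) > 10 or self.bad_num.search(self._s):
            ok = self._bad('bad syntax')
        if ok:
            v = int(self._s)
            if v < lo or v > hi:
                ok = self._bad('out of range %d..%d' % (lo, hi))
        return self._rtn(ok, v, lo)

port = Checked('8080').number(1, 65535)
fallback = Checked('99999').number(1, 65535)
```

Each step only runs while ok is still true, and the single exit
through _rtn cross-checks the local flag against the recorded state.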

Signed-off-by: Ian Jackson <ijackson@chiark.greenend.org.uk>
make-secnet-sites

index ab125ffaed3f7de5e44c136266efdf8eef6753c1..cff283c8d9b8756026caecb7a26d3c8e6a1e00dd 100755 (executable)
@@ -61,6 +61,7 @@ import os
 import getopt
 import re
 import argparse
+import math
 
 import ipaddress
 
@@ -80,6 +81,120 @@ if version_info.major == 2:  # for python2
     import io
     open=lambda f,m='r': io.open(f,m,encoding='utf-8')
 
+max={'rsa_bits':8200,'name':33,'dh_bits':8200}
+
+class Tainted:
+       def __init__(self,s):
+               self._s=s
+               self._ok=None
+               self._line=line
+               self._file=file
+       def __eq__(self,e):
+               return self._s==e
+       def __ne__(self,e):
+               # for Python2
+               return not self.__eq__(e)
+       def __str__(self):
+               raise RuntimeError('direct use of Tainted value')
+       def __repr__(self):
+               return 'Tainted(%s)' % repr(self._s)
+
+       def _bad(self,what,why):
+               assert(self._ok is not True)
+               self._ok=False
+               complain('bad parameter: %s: %s' % (what, why))
+               return self
+
+       def _max_ok(self,what,maxlen):
+               if len(self._s) > maxlen:
+                       self._bad(what,'too long (max %d)' % maxlen)
+               return self
+
+       def _re_ok(self,bad,what,maxlen=None):
+               if maxlen is None: maxlen=max[what]
+               self._max_ok(what,maxlen)
+               if self._ok is False: return self
+               if bad.search(self._s): return self._bad(what,'bad syntax')
+               return self
+
+       def _rtnval(self, is_ok, ifgood, ifbad=''):
+               if is_ok:
+                       assert(self._ok is not False)
+                       self._ok=True
+                       return ifgood
+               else:
+                       assert(self._ok is not True)
+                       self._ok=False
+                       return ifbad
+
+       def _rtn(self, is_ok, ifbad=''):
+               return self._rtnval(is_ok, self._s, ifbad)
+
+       def raw(self):
+               return self._s
+       def raw_mark_ok(self):
+               # caller promises to throw if syntax was dangerous
+               return self._rtn(True)
+
+       bad_name=re.compile(r'^[^a-zA-Z]|[^-_0-9a-zA-Z]')
+       # secnet accepts _ at start of names, but we reserve that
+       bad_name_counter=0
+       def name(self):
+               ok=self._re_ok(Tainted.bad_name,'name')
+               return self._rtn(ok,
+                                '_line%d_%s' % (self._line, id(self)))
+
+       def keyword(self):
+               ok=self._s in keywords or self._s in levels
+               if not ok:
+                       complain('unknown keyword %s' % self._s)
+               return self._rtn(ok)
+
+       bad_hex=re.compile(r'[^0-9a-fA-F]')
+       def bignum_16(self,kind,what):
+               maxlen=(max[kind+'_bits']+3)/4
+               ok=self._re_ok(Tainted.bad_hex,what,maxlen)
+               return self._rtn(ok)
+
+       bad_num=re.compile(r'[^0-9]')
+       def bignum_10(self,kind,what):
+               maxlen=math.ceil(max[kind+'_bits'] / math.log10(2))
+               ok=self._re_ok(Tainted.bad_num,what,maxlen)
+               return self._rtn(ok)
+
+       def number(self,minn,maxx,what='number'):
+               # not for bignums
+               ok=self._re_ok(Tainted.bad_num,what,10)
+               if ok:
+                       v=int(self._s)
+                       if v<minn or v>maxx:
+                               ok=self._bad(what,'out of range %d..%d'
+                                            % (minn,maxx))
+               return self._rtnval(ok,v,minn)
+
+       bad_host=re.compile(r'[^-\][_.:0-9a-zA-Z]')
+       # We permit _ so we can refer to special non-host domains
+       # which have A and AAAA RRs.  This is a crude check and we may
+       # still produce config files with syntactically invalid
+       # domains or addresses, but that is OK.
+       def host(self):
+               ok=self._re_ok(Tainted.bad_host,'host/address',255)
+               return self._rtn(ok)
+
+       bad_email=re.compile(r'[^-._0-9a-z@!$%^&*=+~/]')
+       # ^ This does not accept all valid email addresses.  That's
+       # not really possible with this input syntax.  It accepts
+       # all ones that don't require quoting anywhere in email
+       # protocols (and also accepts some invalid ones).
+       def email(self):
+               ok=self._re_ok(Tainted.bad_email,'email address',1023)
+               return self._rtn(ok)
+
+       bad_groupname=re.compile(r'^[^_A-Za-z]|[^-+_0-9A-Za-z]')
+       def groupname(self):
+               ok=self._re_ok(Tainted.bad_groupname,'group name',64)
+               return self._rtn(ok)
+
 def parse_args():
        global service
        global inputfile
@@ -135,7 +250,7 @@ class basetype:
        "Common protocol for configuration types."
        def add(self,obj,w):
                complain("%s %s already has property %s defined"%
-                       (obj.type,obj.name,w[0]))
+                       (obj.type,obj.name,w[0].raw()))
 
 class conflist:
        "A list of some kind of configuration type."
@@ -152,7 +267,7 @@ def listof(subtype):
 class single_ipaddr (basetype):
        "An IP address"
        def __init__(self,w):
-               self.addr=ipaddress.ip_address(w[1])
+               self.addr=ipaddress.ip_address(w[1].raw_mark_ok())
        def __str__(self):
                return '"%s"'%self.addr
 
@@ -161,7 +276,7 @@ class networks (basetype):
        def __init__(self,w):
                self.set=ipaddrset.IPAddressSet()
                for i in w[1:]:
-                       x=ipaddress.ip_network(i,strict=True)
+                       x=ipaddress.ip_network(i.raw_mark_ok(),strict=True)
                        self.set.append([x])
        def __str__(self):
                return ",".join(map((lambda n: '"%s"'%n), self.set.networks()))
@@ -169,8 +284,8 @@ class networks (basetype):
 class dhgroup (basetype):
        "A Diffie-Hellman group"
        def __init__(self,w):
-               self.mod=w[1]
-               self.gen=w[2]
+               self.mod=w[1].bignum_16('dh','dh mod')
+               self.gen=w[2].bignum_16('dh','dh gen')
        def __str__(self):
                return 'diffie-hellman("%s","%s")'%(self.mod,self.gen)
 
@@ -178,26 +293,32 @@ class hash (basetype):
        "A choice of hash function"
        def __init__(self,w):
                hname=w[1]
-               self.ht=hname
+               self.ht=hname.raw()
                if (self.ht!='md5' and self.ht!='sha1'):
                        complain("unknown hash type %s"%(self.ht))
+                       self.ht=None
+               else:
+                       hname.raw_mark_ok()
        def __str__(self):
                return '%s'%(self.ht)
 
 class email (basetype):
        "An email address"
        def __init__(self,w):
-               self.addr=w[1]
+               self.addr=w[1].email()
        def __str__(self):
                return '<%s>'%(self.addr)
 
 class boolean (basetype):
        "A boolean"
        def __init__(self,w):
-               if re.match('[TtYy1]',w[1]):
+               v=w[1]
+               if re.match('[TtYy1]',v.raw()):
                        self.b=True
-               elif re.match('[FfNn0]',w[1]):
+                       v.raw_mark_ok()
+               elif re.match('[FfNn0]',v.raw()):
                        self.b=False
+                       v.raw_mark_ok()
                else:
                        complain("invalid boolean value");
        def __str__(self):
@@ -206,26 +327,24 @@ class boolean (basetype):
 class num (basetype):
        "A decimal number"
        def __init__(self,w):
-               self.n=int(w[1])
+               self.n=w[1].number(0,0x7fffffff)
        def __str__(self):
                return '%d'%(self.n)
 
 class address (basetype):
        "A DNS name and UDP port number"
        def __init__(self,w):
-               self.adr=w[1]
-               self.port=int(w[2])
-               if (self.port<1 or self.port>65535):
-                       complain("invalid port number")
+               self.adr=w[1].host()
+               self.port=w[2].number(1,65536,'port')
        def __str__(self):
                return '"%s"; port %d'%(self.adr,self.port)
 
 class rsakey (basetype):
        "An RSA public key"
        def __init__(self,w):
-               self.l=int(w[1])
-               self.e=w[2]
-               self.n=w[3]
+               self.l=w[1].number(0,max['rsa_bits'],'rsa len')
+               self.e=w[2].bignum_10('rsa','rsa e')
+               self.n=w[3].bignum_10('rsa','rsa n')
        def __str__(self):
                return 'rsa-public("%s","%s")'%(self.e,self.n)
 
@@ -271,8 +390,8 @@ class level:
        allow_properties={}
        require_properties={}
        def __init__(self,w):
-               self.type=w[0]
-               self.name=w[1]
+               self.type=w[0].keyword()
+               self.name=w[1].name()
                self.properties={}
                self.children={}
        def indent(self,w,t):
@@ -330,7 +449,7 @@ class locationlevel(level):
        }
        def __init__(self,w):
                level.__init__(self,w)
-               self.group=w[2]
+               self.group=w[2].groupname()
        def output_vpnflat(self,w,ind,h):
                self.indent(w,ind)
                # The "h=h,self=self" abomination below exists because
@@ -399,7 +518,13 @@ def moan(msg):
        print(msg);
        complaints=complaints+1
 
-root=level(['root','root'])   # All vpns are children of this node
+class UntaintedRoot():
+       def __init__(self,s): self._s=s
+       def name(self): return self._s
+       def keyword(self): return self._s
+
+root=level([UntaintedRoot(x) for x in ['root','root']])
+# All vpns are children of this node
 obstack=[root]
 allow_defs=0   # Level above which new definitions are permitted
 prefix=''
@@ -407,19 +532,21 @@ prefix=''
 def set_property(obj,w):
        "Set a property on a configuration node"
        prop=w[0]
-       if prop in obj.properties:
-               obj.properties[prop].add(obj,w)
+       if prop.raw() in obj.properties:
+               obj.properties[prop.raw_mark_ok()].add(obj,w)
        else:
-               obj.properties[prop]=keywords[prop][0](w)
+               obj.properties[prop.raw()]=keywords[prop.raw_mark_ok()][0](w)
 
 def pline(i,allow_include=False):
        "Process a configuration file line"
        global allow_defs, obstack, root
        w=i.rstrip('\n').split()
        if len(w)==0: return [i]
+       w=list([Tainted(x) for x in w])
        keyword=w[0]
        current=obstack[len(obstack)-1]
        if keyword=='end-definitions':
+               keyword.raw_mark_ok()
                allow_defs=sitelevel.depth
                obstack=[root]
                return [i]
@@ -430,11 +557,12 @@ def pline(i,allow_include=False):
                if len(w) != 2:
                        complain("include requires one argument")
                        return []
-               newfile=os.path.join(os.path.dirname(file),w[1])
+               newfile=os.path.join(os.path.dirname(file),w[1].raw_mark_ok())
+               # ^ user of "include" is trusted so raw_mark_ok is good
                return pfilepath(newfile,allow_include=allow_include)
-       if keyword in levels:
+       if keyword.raw() in levels:
                # We may go up any number of levels, but only down by one
-               newdepth=levels[keyword].depth
+               newdepth=levels[keyword.raw_mark_ok()].depth
                currentdepth=len(obstack) # actually +1...
                if newdepth<=currentdepth:
                        obstack=obstack[:newdepth]
@@ -444,7 +572,7 @@ def pline(i,allow_include=False):
                # See if it's a new one (and whether that's permitted)
                # or an existing one
                current=obstack[len(obstack)-1]
-               tname=w[1]
+               tname=w[1].name()
                if tname in current.children:
                        # Not new
                        current=current.children[tname]
@@ -454,7 +582,7 @@ def pline(i,allow_include=False):
                else:
                        # New
                        # Ignore depth check for now
-                       nl=levels[keyword](w)
+                       nl=levels[keyword.raw()](w)
                        if nl.depth<allow_defs:
                                complain("New definitions not allowed at "
                                        "level %d"%nl.depth)
@@ -464,9 +592,9 @@ def pline(i,allow_include=False):
                        current=nl
                obstack.append(current)
                return [i]
-       if keyword not in current.allow_properties:
+       if keyword.raw() not in current.allow_properties:
                complain("Property %s not allowed at %s level"%
-                       (keyword,current.type))
+                       (keyword.raw(),current.type))
                return []
        elif current.depth == vpnlevel.depth < allow_defs:
                complain("Not allowed to set VPN properties here")
@@ -475,7 +603,7 @@ def pline(i,allow_include=False):
                set_property(current,w)
                return [i]
 
-       complain("unknown keyword '%s'"%(keyword))
+       complain("unknown keyword '%s'"%(keyword.raw()))
 
 def pfilepath(pathname,allow_include=False):
        f=open(pathname)