chiark / gitweb /
Found in crybaby's working tree.
[chopwood] / backend.py
index 1967cda5dd76b2440331fd5358fe49e9d3b049dc..2596306363e7c6171107bdbfa8d26b81de8f40c9 100644 (file)
@@ -25,6 +25,9 @@
 
 from __future__ import with_statement
 
+from auto import HOME
+import errno as E
+import itertools as I
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
@@ -36,7 +39,75 @@ import util as U
 CONF.DEFAULTS.update(
 
   ## A directory in which we can create lockfiles.
-  LOCKDIR = OS.path.join(ENV['HOME'], 'var', 'lock', 'chpwd'))
+  LOCKDIR = OS.path.join(HOME, 'lock'))
+
+###--------------------------------------------------------------------------
+### Utilities.
+
+def fill_in_fields(fno_user, fno_passwd, fno_map,
+                   user, passwd, args, defaults = None):
+  """
+  Return a vector of filled-in fields.
+
+  The FNO_... arguments give field numbers: FNO_USER and FNO_PASSWD give the
+  positions for the username and password fields, respectively; and FNO_MAP
+  is a sequence of (NAME, POS) pairs.  The USER and PASSWD arguments give the
+  actual user name and password values; ARGS are the remaining arguments,
+  maybe in the form `NAME=VALUE'.
+
+  If DEFAULTS is given, it is a fully-filled-in sequence of records.  The
+  user name and password are taken from the DEFAULTS if USER and PASSWD are
+  `None' on entry; they cannot be set from ARGS.
+  """
+
+  ## Prepare the result vector, and set up some data structures.
+  n = 2 + len(fno_map)
+  fmap = {}
+  rmap = map(int, xrange(n))
+  ok = True
+  if fno_user >= n or fno_passwd >= n: ok = False
+  for k, i in fno_map:
+    fmap[k] = i
+    if i in rmap: ok = False
+    else: rmap[i] = "`%s'" % k
+    if i >= n: ok = False
+  if not ok:
+    raise U.ExpectedError, \
+        (500, "Fields specified aren't contiguous")
+
+  ## Prepare the new record's fields.
+  f = [None]*n
+  if user is not None: f[fno_user] = user
+  elif defaults is not None: f[no_user] = defaults[no_user]
+  else: raise U.ExpectedError, (500, "No user name given")
+  if passwd is not None: f[fno_passwd] = passwd
+  elif defaults is not None: f[no_passwd] = defaults[no_passwd]
+  else: raise U.ExpectedError, (500, "No password given")
+
+  for a in args:
+    if '=' in a:
+      k, v = a.split('=', 1)
+      try: i = fmap[k]
+      except KeyError: raise U.ExpectedError, (400, "Unknown field `%s'" % k)
+    else:
+      for i in xrange(n):
+        if f[i] is None: break
+      else:
+        raise U.ExpectedError, (500, "All fields already populated")
+      v = a
+    if f[i] is not None:
+      raise U.ExpectedError, (400, "Field %s is already set" % rmap[i])
+    f[i] = v
+
+  ## Check that the vector of fields is properly set up.  Copy unset values
+  ## from the defaults if they're available.
+  for i in xrange(n):
+    if f[i] is not None: pass
+    elif defaults is not None: f[i] = defaults[i]
+    else: raise U.ExpectedError, (500, "Field %s is unset" % rmap[i])
+
+  ## Done.
+  return f
 
 ###--------------------------------------------------------------------------
 ### Protocol.
@@ -76,6 +147,8 @@ class BasicRecord (object):
     me._be = backend
   def write(me):
     me._be._update(me)
+  def remove(me):
+    me._be._remove(me)
 
 class TrivialRecord (BasicRecord):
   """
@@ -104,7 +177,7 @@ class FlatFileRecord (BasicRecord):
   is careful to do this).
   """
 
-  def __init__(me, line, delim, fmap, *args, **kw):
+  def __init__(me, line, delim, fno_user, fno_passwd, fmap, *args, **kw):
     """
     Initialize the record, splitting the LINE into fields separated by DELIM,
     and setting attributes under control of FMAP.
@@ -113,6 +186,8 @@ class FlatFileRecord (BasicRecord):
     line = line.rstrip('\n')
     fields = line.split(delim)
     me._delim = delim
+    me._fno_user = fno_user
+    me._fno_passwd = fno_passwd
     me._fmap = fmap
     me._raw = fields
     for k, v in fmap.iteritems():
@@ -138,6 +213,14 @@ class FlatFileRecord (BasicRecord):
       fields[v] = val
     return me._delim.join(fields) + '\n'
 
+  def edit(me, args):
+    """
+    Modify the record as described by the ARGS.
+
+    Fields other than the user name and password can be modified.
+    """
+    ff = fill_in_fields(
+
 class FlatFileBackend (object):
   """
   Password storage in a flat passwd(5)-style file.
@@ -147,10 +230,13 @@ class FlatFileBackend (object):
   specified by the DELIM constructor argument.
 
   The file is updated by writing a new version alongside, as `FILE.new', and
-  renaming it over the old version.  If a LOCK file is named then an
-  exclusive fcntl(2)-style lock is taken out on `LOCKDIR/LOCK' (creating the
-  file if necessary) during the update operation.  Use of a lockfile is
-  strongly recommended.
+  renaming it over the old version.  If a LOCK is provided then this is done
+  while holding a lock.  By default, an exclusive fcntl(2)-style lock is
+  taken out on `LOCKDIR/LOCK' (creating the file if necessary) during the
+  update operation, but subclasses can override the `dolocked' method to
+  provide alternative locking behaviour; the LOCK parameter is not
+  interpreted by any other methods.  Use of a lockfile is strongly
+  recommended.
 
   The DELIM constructor argument specifies the delimiter character used when
   splitting lines into fields.  The USER and PASSWD arguments give the field
@@ -181,8 +267,38 @@ class FlatFileBackend (object):
           return rec
     raise UnknownUser, user
 
-  def _update(me, rec):
-    """Update the record REC in the file."""
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right.
+    """
+
+    f = fill_in_fields(me._fmap['user'], me._fmap['passwd'],
+                       [(k[2:], i)
+                        for k, i in me._fmap.iteritems()
+                        if k.startswith('f_')],
+                       user, passwd, args)
+    r = FlatFileRecord(me._delim.join(f), me._delim, me._fmap, backend = me)
+    me._rewrite('create', r)
+
+  def _rewrite(me, op, rec):
+    """
+    Rewrite the file, according to OP.
+
+    The OP may be one of the following.
+
+    `create'            There must not be a record matching REC; add a new
+                        one.
+
+    `remove'            There must be a record matching REC: remove it.
+
+    `update'            There must be a record matching REC: write REC in its
+                        place.
+    """
 
     ## The main update function.
     def doit():
@@ -200,14 +316,26 @@ class FlatFileBackend (object):
 
         ## Copy the old file to the new one, changing the user's record if
         ## and when we encounter it.
+        found = False
         with OS.fdopen(fd, 'w') as f_out:
           with open(me._file) as f_in:
             for line in f_in:
               r = me._parse(line)
               if r.user != rec.user:
                 f_out.write(line)
+              elif op == 'create':
+                raise U.ExpectedError, \
+                    (500, "Record for `%s' already exists" % rec.user)
               else:
-                f_out.write(rec._format())
+                found = True
+                if op != 'remove': f_out.write(rec._format())
+          if found:
+            pass
+          elif op == 'create':
+            f_out.write(rec._format())
+          else:
+            raise U.ExpectedError, \
+                (500, "Record for `%s' not found" % rec.user)
 
         ## Update the permissions on the new file.  Don't try to fix the
         ## ownership (we shouldn't be running as root) or the group (the
@@ -228,16 +356,34 @@ class FlatFileBackend (object):
 
     ## If there's a lockfile, then acquire it around the meat of this
     ## function; otherwise just do the job.
-    if me._lock is None:
-      doit()
-    else:
-      with U.lockfile(OS.path.join(CFG.LOCKDIR, me._lock), 5):
-        doit()
+    if me._lock is None: doit()
+    else: me.dolocked(me._lock, doit)
+
+  def dolocked(me, lock, func):
+    """
+    Call FUNC with the LOCK held.
+
+    Subclasses can override this method in order to provide alternative
+    locking functionality.
+    """
+    try: OS.mkdir(CFG.LOCKDIR)
+    except OSError, e:
+      if e.errno != E.EEXIST: raise
+    with U.lockfile(OS.path.join(CFG.LOCKDIR, lock), 5):
+      func()
 
   def _parse(me, line):
     """Convenience function for constructing a record."""
     return FlatFileRecord(line, me._delim, me._fmap, backend = me)
 
+  def _update(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('update', rec)
+
+  def _remove(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('remove', rec)
+
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
@@ -294,6 +440,36 @@ class DatabaseBackend (object):
       setattr(rec, 'f_' + f, v)
     return rec
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the named USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right, in the order given to the constructor.
+    """
+
+    tags = ['user', 'passwd'] + \
+        ['t_%d' % 0 for i in xrange(len(me._fields))]
+    f = fill_in_fields(0, 1, list(I.izip(me._fields, I.count(2))),
+                       user, passwd, args)
+    me._connect()
+    with me._db:
+      me._db.execute("INSERT INTO %s (%s) VALUES (%s)" %
+                     (me._table,
+                      ', '.join([me._user, me._passwd] + me._fields),
+                      ', '.join(['$%s' % t for t in tags])),
+                     **dict(I.izip(tags, f)))
+
+  def _remove(me, rec):
+    """Remove the record REC from the database."""
+    me._connect()
+    with me._db:
+      me._db.execute("DELETE FROM %s WHERE %s = $user" %
+                     (me._table, me._user),
+                     user = rec.user)
+
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()