chiark / gitweb /
cgi.py: Set the default static URL prefix from user's `SCRIPT_NAME'.
[chopwood] / backend.py
index 1967cda5dd76b2440331fd5358fe49e9d3b049dc..d6fac7832d6bf722b93cef3dadddd0def542ffd3 100644 (file)
@@ -25,6 +25,7 @@
 
 from __future__ import with_statement
 
+import itertools as I
 import os as OS; ENV = OS.environ
 
 import config as CONF; CFG = CONF.CFG
@@ -38,6 +39,62 @@ CONF.DEFAULTS.update(
   ## A directory in which we can create lockfiles.
   LOCKDIR = OS.path.join(ENV['HOME'], 'var', 'lock', 'chpwd'))
 
+###--------------------------------------------------------------------------
+### Utilities.
+
+def fill_in_fields(fno_user, fno_passwd, fno_map, user, passwd, args):
+  """
+  Return a vector of filled-in fields.
+
+  The FNO_... arguments give field numbers: FNO_USER and FNO_PASSWD give the
+  positions for the username and password fields, respectively; and FNO_MAP
+  is a sequence of (NAME, POS) pairs.  The USER and PASSWD arguments give the
+  actual user name and password values; ARGS are the remaining arguments,
+  maybe in the form `NAME=VALUE'.
+  """
+
+  ## Prepare the result vector, and set up some data structures.
+  n = 2 + len(fno_map)
+  fmap = {}
+  rmap = map(int, xrange(n))
+  ok = True
+  if fno_user >= n or fno_passwd >= n: ok = False
+  for k, i in fno_map:
+    fmap[k] = i
+    rmap[i] = "`%s'" % k
+    if i >= n: ok = False
+  if not ok:
+    raise U.ExpectedError, \
+        (500, "Fields specified aren't contiguous")
+
+  ## Prepare the new record's fields.
+  f = [None]*n
+  f[fno_user] = user
+  f[fno_passwd] = passwd
+
+  for a in args:
+    if '=' in a:
+      k, v = a.split('=', 1)
+      try: i = fmap[k]
+      except KeyError: raise U.ExpectedError, (400, "Unknown field `%s'" % k)
+    else:
+      for i in xrange(n):
+        if f[i] is None: break
+      else:
+        raise U.ExpectedError, (500, "All fields already populated")
+      v = a
+    if f[i] is not None:
+      raise U.ExpectedError, (400, "Field %s is already set" % rmap[i])
+    f[i] = v
+
+  ## Check that the vector of fields is properly set up.
+  for i in xrange(n):
+    if f[i] is None:
+      raise U.ExpectedError, (500, "Field %s is unset" % rmap[i])
+
+  ## Done.
+  return f
+
 ###--------------------------------------------------------------------------
 ### Protocol.
 ###
@@ -76,6 +133,8 @@ class BasicRecord (object):
     me._be = backend
   def write(me):
     me._be._update(me)
+  def remove(me):
+    me._be._remove(me)
 
 class TrivialRecord (BasicRecord):
   """
@@ -181,8 +240,38 @@ class FlatFileBackend (object):
           return rec
     raise UnknownUser, user
 
-  def _update(me, rec):
-    """Update the record REC in the file."""
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right.
+    """
+
+    f = fill_in_fields(me._fmap['user'], me._fmap['passwd'],
+                       [(k[2:], i)
+                        for k, i in me._fmap.iteritems()
+                        if k.startswith('f_')],
+                       user, passwd, args)
+    r = FlatFileRecord(me._delim.join(f), me._delim, me._fmap, backend = me)
+    me._rewrite('create', r)
+
+  def _rewrite(me, op, rec):
+    """
+    Rewrite the file, according to OP.
+
+    The OP may be one of the following.
+
+    `create'            There must not be a record matching REC; add a new
+                        one.
+
+    `remove'            There must be a record matching REC: remove it.
+
+    `update'            There must be a record matching REC: write REC in its
+                        place.
+    """
 
     ## The main update function.
     def doit():
@@ -200,14 +289,26 @@ class FlatFileBackend (object):
 
         ## Copy the old file to the new one, changing the user's record if
         ## and when we encounter it.
+        found = False
         with OS.fdopen(fd, 'w') as f_out:
           with open(me._file) as f_in:
             for line in f_in:
               r = me._parse(line)
               if r.user != rec.user:
                 f_out.write(line)
+              elif op == 'create':
+                raise U.ExpectedError, \
+                    (500, "Record for `%s' already exists" % rec.user)
               else:
-                f_out.write(rec._format())
+                found = True
+                if op != 'remove': f_out.write(rec._format())
+          if found:
+            pass
+          elif op == 'create':
+            f_out.write(rec._format())
+          else:
+            raise U.ExpectedError, \
+                (500, "Record for `%s' not found" % rec.user)
 
         ## Update the permissions on the new file.  Don't try to fix the
         ## ownership (we shouldn't be running as root) or the group (the
@@ -238,6 +339,14 @@ class FlatFileBackend (object):
     """Convenience function for constructing a record."""
     return FlatFileRecord(line, me._delim, me._fmap, backend = me)
 
+  def _update(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('update', rec)
+
+  def _remove(me, rec):
+    """Update the record REC in the file."""
+    me._rewrite('remove', rec)
+
 CONF.export('FlatFileBackend')
 
 ###--------------------------------------------------------------------------
@@ -294,6 +403,36 @@ class DatabaseBackend (object):
       setattr(rec, 'f_' + f, v)
     return rec
 
+  def create(me, user, passwd, args):
+    """
+    Create a new record for the named USER.
+
+    The new record has the given PASSWD, and other fields are set from ARGS.
+    Those ARGS of the form `KEY=VALUE' set the appropriately named fields (as
+    set up by the constructor); other ARGS fill in unset fields, left to
+    right, in the order given to the constructor.
+    """
+
+    tags = ['user', 'passwd'] + \
+        ['t_%d' % 0 for i in xrange(len(me._fields))]
+    f = fill_in_fields(0, 1, list(I.izip(me._fields, I.count(2))),
+                       user, passwd, args)
+    me._connect()
+    with me._db:
+      me._db.execute("INSERT INTO %s (%s) VALUES (%s)" %
+                     (me._table,
+                      ', '.join([me._user, me._passwd] + me._fields),
+                      ', '.join(['$%s' % t for t in tags])),
+                     **dict(I.izip(tags, f)))
+
+  def _remove(me, rec):
+    """Remove the record REC from the database."""
+    me._connect()
+    with me._db:
+      me._db.execute("DELETE FROM %s WHERE %s = $user" %
+                     (me._table, me._user),
+                     user = rec.user)
+
   def _update(me, rec):
     """Update the record REC in the database."""
     me._connect()