chiark / gitweb /
Generate binary diffs by default
[stgit] / stgit / lib / git.py
index 2af1844e4c887ea079fb8393c661f8f8650b0c4d..65d2a6c3f68a0e82a2bea4c784a4e4e3caa6e93b 100644 (file)
+"""A Python class hierarchy wrapping a git repository and its
+contents."""
+
 import os, os.path, re
+from datetime import datetime, timedelta, tzinfo
+
 from stgit import exception, run, utils
+from stgit.config import config
+
+class Immutable(object):
+    """I{Immutable} objects cannot be modified once created. Any
+    modification methods will return a new object, leaving the
+    original object as it was.
+
+    The reason for this is that we want to be able to represent git
+    objects, which are immutable, and want to be able to create new
+    git objects that are just slight modifications of other git
+    objects. (Such as, for example, modifying the commit message of a
+    commit object while leaving the rest of it intact. This involves
+    creating a whole new commit object that's exactly like the old one
+    except for the commit message.)
+
+    The L{Immutable} class doesn't actually enforce immutability --
+    that is up to the individual immutable subclasses. It just serves
+    as documentation."""
 
 class RepositoryException(exception.StgException):
-    pass
+    """Base class for all exceptions due to failed L{Repository}
+    operations."""
+
+class BranchException(exception.StgException):
+    """Exception raised by failed L{Branch} operations."""
+
+class DateException(exception.StgException):
+    """Exception raised when a date+time string could not be parsed."""
+    def __init__(self, string, type):
+        exception.StgException.__init__(
+            self, '"%s" is not a valid %s' % (string, type))
 
 class DetachedHeadException(RepositoryException):
+    """Exception raised when HEAD is detached (that is, there is no
+    current branch)."""
     def __init__(self):
         RepositoryException.__init__(self, 'Not on any branch')
 
 class Repr(object):
+    """Utility class that defines C{__reps__} in terms of C{__str__}."""
     def __repr__(self):
         return str(self)
 
 class NoValue(object):
+    """A handy default value that is guaranteed to be distinct from any
+    real argument value."""
     pass
 
 def make_defaults(defaults):
-    def d(val, attr):
+    def d(val, attr, default_fun = lambda: None):
         if val != NoValue:
             return val
         elif defaults != NoValue:
             return getattr(defaults, attr)
         else:
-            return None
+            return default_fun()
     return d
 
-class Person(Repr):
-    """Immutable."""
+class TimeZone(tzinfo, Repr):
+    """A simple time zone class for static offsets from UTC. (We have to
+    define our own since Python's standard library doesn't define any
+    time zone classes.)"""
+    def __init__(self, tzstring):
+        m = re.match(r'^([+-])(\d{2}):?(\d{2})$', tzstring)
+        if not m:
+            raise DateException(tzstring, 'time zone')
+        sign = int(m.group(1) + '1')
+        try:
+            self.__offset = timedelta(hours = sign*int(m.group(2)),
+                                      minutes = sign*int(m.group(3)))
+        except OverflowError:
+            raise DateException(tzstring, 'time zone')
+        self.__name = tzstring
+    def utcoffset(self, dt):
+        return self.__offset
+    def tzname(self, dt):
+        return self.__name
+    def dst(self, dt):
+        return timedelta(0)
+    def __str__(self):
+        return self.__name
+
+class Date(Immutable, Repr):
+    """Represents a timestamp used in git commits."""
+    def __init__(self, datestring):
+        # Try git-formatted date.
+        m = re.match(r'^(\d+)\s+([+-]\d\d:?\d\d)$', datestring)
+        if m:
+            try:
+                self.__time = datetime.fromtimestamp(int(m.group(1)),
+                                                     TimeZone(m.group(2)))
+            except ValueError:
+                raise DateException(datestring, 'date')
+            return
+
+        # Try iso-formatted date.
+        m = re.match(r'^(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})\s+'
+                     + r'([+-]\d\d:?\d\d)$', datestring)
+        if m:
+            try:
+                self.__time = datetime(
+                    *[int(m.group(i + 1)) for i in xrange(6)],
+                    **{'tzinfo': TimeZone(m.group(7))})
+            except ValueError:
+                raise DateException(datestring, 'date')
+            return
+
+        raise DateException(datestring, 'date')
+    def __str__(self):
+        return self.isoformat()
+    def isoformat(self):
+        """Human-friendly ISO 8601 format."""
+        return '%s %s' % (self.__time.replace(tzinfo = None).isoformat(' '),
+                          self.__time.tzinfo)
+    @classmethod
+    def maybe(cls, datestring):
+        """Return a new object initialized with the argument if it contains a
+        value (otherwise, just return the argument)."""
+        if datestring in [None, NoValue]:
+            return datestring
+        return cls(datestring)
+
+class Person(Immutable, Repr):
+    """Represents an author or committer in a git commit object. Contains
+    name, email and timestamp."""
     def __init__(self, name = NoValue, email = NoValue,
                  date = NoValue, defaults = NoValue):
         d = make_defaults(defaults)
         self.__name = d(name, 'name')
         self.__email = d(email, 'email')
         self.__date = d(date, 'date')
+        assert isinstance(self.__date, Date) or self.__date in [None, NoValue]
     name = property(lambda self: self.__name)
     email = property(lambda self: self.__email)
+    name_email = property(lambda self: '%s <%s>' % (self.name, self.email))
     date = property(lambda self: self.__date)
     def set_name(self, name):
         return type(self)(name = name, defaults = self)
@@ -43,33 +148,184 @@ class Person(Repr):
     def set_date(self, date):
         return type(self)(date = date, defaults = self)
     def __str__(self):
-        return '%s <%s> %s' % (self.name, self.email, self.date)
+        return '%s %s' % (self.name_email, self.date)
     @classmethod
     def parse(cls, s):
         m = re.match(r'^([^<]*)<([^>]*)>\s+(\d+\s+[+-]\d{4})$', s)
         assert m
         name = m.group(1).strip()
         email = m.group(2)
-        date = m.group(3)
+        date = Date(m.group(3))
         return cls(name, email, date)
+    @classmethod
+    def user(cls):
+        if not hasattr(cls, '__user'):
+            cls.__user = cls(name = config.get('user.name'),
+                             email = config.get('user.email'))
+        return cls.__user
+    @classmethod
+    def author(cls):
+        if not hasattr(cls, '__author'):
+            cls.__author = cls(
+                name = os.environ.get('GIT_AUTHOR_NAME', NoValue),
+                email = os.environ.get('GIT_AUTHOR_EMAIL', NoValue),
+                date = Date.maybe(os.environ.get('GIT_AUTHOR_DATE', NoValue)),
+                defaults = cls.user())
+        return cls.__author
+    @classmethod
+    def committer(cls):
+        if not hasattr(cls, '__committer'):
+            cls.__committer = cls(
+                name = os.environ.get('GIT_COMMITTER_NAME', NoValue),
+                email = os.environ.get('GIT_COMMITTER_EMAIL', NoValue),
+                date = Date.maybe(
+                    os.environ.get('GIT_COMMITTER_DATE', NoValue)),
+                defaults = cls.user())
+        return cls.__committer
+
+class GitObject(Immutable, Repr):
+    """Base class for all git objects. One git object is represented by at
+    most one C{GitObject}, which makes it possible to compare them
+    using normal Python object comparison; it also ensures we don't
+    waste more memory than necessary."""
 
-class Tree(Repr):
-    """Immutable."""
-    def __init__(self, sha1):
+class BlobData(Immutable, Repr):
+    """Represents the data contents of a git blob object."""
+    def __init__(self, string):
+        self.__string = str(string)
+    str = property(lambda self: self.__string)
+    def commit(self, repository):
+        """Commit the blob.
+        @return: The committed blob
+        @rtype: L{Blob}"""
+        sha1 = repository.run(['git', 'hash-object', '-w', '--stdin']
+                              ).raw_input(self.str).output_one_line()
+        return repository.get_blob(sha1)
+
+class Blob(GitObject):
+    """Represents a git blob object. All the actual data contents of the
+    blob object is stored in the L{data} member, which is a
+    L{BlobData} object."""
+    typename = 'blob'
+    default_perm = '100644'
+    def __init__(self, repository, sha1):
+        self.__repository = repository
         self.__sha1 = sha1
     sha1 = property(lambda self: self.__sha1)
     def __str__(self):
-        return 'Tree<%s>' % self.sha1
+        return 'Blob<%s>' % self.sha1
+    @property
+    def data(self):
+        return BlobData(self.__repository.cat_object(self.sha1))
+
+class ImmutableDict(dict):
+    """A dictionary that cannot be modified once it's been created."""
+    def error(*args, **kwargs):
+        raise TypeError('Cannot modify immutable dict')
+    __delitem__ = error
+    __setitem__ = error
+    clear = error
+    pop = error
+    popitem = error
+    setdefault = error
+    update = error
 
-class Commitdata(Repr):
-    """Immutable."""
+class TreeData(Immutable, Repr):
+    """Represents the data contents of a git tree object."""
+    @staticmethod
+    def __x(po):
+        if isinstance(po, GitObject):
+            perm, object = po.default_perm, po
+        else:
+            perm, object = po
+        return perm, object
+    def __init__(self, entries):
+        """Create a new L{TreeData} object from the given mapping from names
+        (strings) to either (I{permission}, I{object}) tuples or just
+        objects."""
+        self.__entries = ImmutableDict((name, self.__x(po))
+                                       for (name, po) in entries.iteritems())
+    entries = property(lambda self: self.__entries)
+    """Map from name to (I{permission}, I{object}) tuple."""
+    def set_entry(self, name, po):
+        """Create a new L{TreeData} object identical to this one, except that
+        it maps C{name} to C{po}.
+
+        @param name: Name of the changed mapping
+        @type name: C{str}
+        @param po: Value of the changed mapping
+        @type po: L{Blob} or L{Tree} or (C{str}, L{Blob} or L{Tree})
+        @return: The new L{TreeData} object
+        @rtype: L{TreeData}"""
+        e = dict(self.entries)
+        e[name] = self.__x(po)
+        return type(self)(e)
+    def del_entry(self, name):
+        """Create a new L{TreeData} object identical to this one, except that
+        it doesn't map C{name} to anything.
+
+        @param name: Name of the deleted mapping
+        @type name: C{str}
+        @return: The new L{TreeData} object
+        @rtype: L{TreeData}"""
+        e = dict(self.entries)
+        del e[name]
+        return type(self)(e)
+    def commit(self, repository):
+        """Commit the tree.
+        @return: The committed tree
+        @rtype: L{Tree}"""
+        listing = ''.join(
+            '%s %s %s\t%s\0' % (mode, obj.typename, obj.sha1, name)
+            for (name, (mode, obj)) in self.entries.iteritems())
+        sha1 = repository.run(['git', 'mktree', '-z']
+                              ).raw_input(listing).output_one_line()
+        return repository.get_tree(sha1)
+    @classmethod
+    def parse(cls, repository, s):
+        """Parse a raw git tree description.
+
+        @return: A new L{TreeData} object
+        @rtype: L{TreeData}"""
+        entries = {}
+        for line in s.split('\0')[:-1]:
+            m = re.match(r'^([0-7]{6}) ([a-z]+) ([0-9a-f]{40})\t(.*)$', line)
+            assert m
+            perm, type, sha1, name = m.groups()
+            entries[name] = (perm, repository.get_object(type, sha1))
+        return cls(entries)
+
+class Tree(GitObject):
+    """Represents a git tree object. All the actual data contents of the
+    tree object is stored in the L{data} member, which is a
+    L{TreeData} object."""
+    typename = 'tree'
+    default_perm = '040000'
+    def __init__(self, repository, sha1):
+        self.__sha1 = sha1
+        self.__repository = repository
+        self.__data = None
+    sha1 = property(lambda self: self.__sha1)
+    @property
+    def data(self):
+        if self.__data == None:
+            self.__data = TreeData.parse(
+                self.__repository,
+                self.__repository.run(['git', 'ls-tree', '-z', self.sha1]
+                                      ).raw_output())
+        return self.__data
+    def __str__(self):
+        return 'Tree<sha1: %s>' % self.sha1
+
+class CommitData(Immutable, Repr):
+    """Represents the data contents of a git commit object."""
     def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
                  committer = NoValue, message = NoValue, defaults = NoValue):
         d = make_defaults(defaults)
         self.__tree = d(tree, 'tree')
         self.__parents = d(parents, 'parents')
-        self.__author = d(author, 'author')
-        self.__committer = d(committer, 'committer')
+        self.__author = d(author, 'author', Person.author)
+        self.__committer = d(committer, 'committer', Person.committer)
         self.__message = d(message, 'message')
     tree = property(lambda self: self.__tree)
     parents = property(lambda self: self.__parents)
@@ -106,12 +362,34 @@ class Commitdata(Repr):
             parents = None
         else:
             parents = [p.sha1 for p in self.parents]
-        return ('Commitdata<tree: %s, parents: %s, author: %s,'
+        return ('CommitData<tree: %s, parents: %s, author: %s,'
                 ' committer: %s, message: "%s">'
                 ) % (tree, parents, self.author, self.committer, self.message)
+    def commit(self, repository):
+        """Commit the commit.
+        @return: The committed commit
+        @rtype: L{Commit}"""
+        c = ['git', 'commit-tree', self.tree.sha1]
+        for p in self.parents:
+            c.append('-p')
+            c.append(p.sha1)
+        env = {}
+        for p, v1 in ((self.author, 'AUTHOR'),
+                       (self.committer, 'COMMITTER')):
+            if p != None:
+                for attr, v2 in (('name', 'NAME'), ('email', 'EMAIL'),
+                                 ('date', 'DATE')):
+                    if getattr(p, attr) != None:
+                        env['GIT_%s_%s' % (v1, v2)] = str(getattr(p, attr))
+        sha1 = repository.run(c, env = env).raw_input(self.message
+                                                      ).output_one_line()
+        return repository.get_commit(sha1)
     @classmethod
     def parse(cls, repository, s):
-        cd = cls()
+        """Parse a raw git commit description.
+        @return: A new L{CommitData} object
+        @rtype: L{CommitData}"""
+        cd = cls(parents = [])
         lines = list(s.splitlines(True))
         for i in xrange(len(lines)):
             line = lines[i].strip()
@@ -130,8 +408,11 @@ class Commitdata(Repr):
                 assert False
         assert False
 
-class Commit(Repr):
-    """Immutable."""
+class Commit(GitObject):
+    """Represents a git commit object. All the actual data contents of the
+    commit object is stored in the L{data} member, which is a
+    L{CommitData} object."""
+    typename = 'commit'
     def __init__(self, repository, sha1):
         self.__sha1 = sha1
         self.__repository = repository
@@ -140,7 +421,7 @@ class Commit(Repr):
     @property
     def data(self):
         if self.__data == None:
-            self.__data = Commitdata.parse(
+            self.__data = CommitData.parse(
                 self.__repository,
                 self.__repository.cat_object(self.sha1))
         return self.__data
@@ -148,21 +429,26 @@ class Commit(Repr):
         return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
 
 class Refs(object):
+    """Accessor for the refs stored in a git repository. Will
+    transparently cache the values of all refs."""
     def __init__(self, repository):
         self.__repository = repository
         self.__refs = None
     def __cache_refs(self):
+        """(Re-)Build the cache of all refs in the repository."""
         self.__refs = {}
         for line in self.__repository.run(['git', 'show-ref']).output_lines():
             m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
             sha1, ref = m.groups()
             self.__refs[ref] = sha1
     def get(self, ref):
-        """Throws KeyError if ref doesn't exist."""
+        """Get the Commit the given ref points to. Throws KeyError if ref
+        doesn't exist."""
         if self.__refs == None:
             self.__cache_refs()
         return self.__repository.get_commit(self.__refs[ref])
     def exists(self, ref):
+        """Check if the given ref exists."""
         try:
             self.get(ref)
         except KeyError:
@@ -170,6 +456,8 @@ class Refs(object):
         else:
             return True
     def set(self, ref, commit, msg):
+        """Write the sha1 of the given Commit to the ref. The ref may or may
+        not already exist."""
         if self.__refs == None:
             self.__cache_refs()
         old_sha1 = self.__refs.get(ref, '0'*40)
@@ -179,6 +467,7 @@ class Refs(object):
                                    ref, new_sha1, old_sha1]).no_output()
             self.__refs[ref] = new_sha1
     def delete(self, ref):
+        """Delete the given ref. Throws KeyError if ref doesn't exist."""
         if self.__refs == None:
             self.__cache_refs()
         self.__repository.run(['git', 'update-ref',
@@ -187,7 +476,8 @@ class Refs(object):
 
 class ObjectCache(object):
     """Cache for Python objects, for making sure that we create only one
-    Python object per git object."""
+    Python object per git object. This reduces memory consumption and
+    makes object comparison very cheap."""
     def __init__(self, create):
         self.__objects = {}
         self.__create = create
@@ -203,13 +493,40 @@ class ObjectCache(object):
 
 class RunWithEnv(object):
     def run(self, args, env = {}):
+        """Run the given command with an environment given by self.env.
+
+        @type args: list of strings
+        @param args: Command and argument vector
+        @type env: dict
+        @param env: Extra environment"""
         return run.Run(*args).env(utils.add_dict(self.env, env))
 
+class RunWithEnvCwd(RunWithEnv):
+    def run(self, args, env = {}):
+        """Run the given command with an environment given by self.env, and
+        current working directory given by self.cwd.
+
+        @type args: list of strings
+        @param args: Command and argument vector
+        @type env: dict
+        @param env: Extra environment"""
+        return RunWithEnv.run(self, args, env).cwd(self.cwd)
+    def run_in_cwd(self, args):
+        """Run the given command with an environment given by self.env and
+        self.env_in_cwd, without changing the current working
+        directory.
+
+        @type args: list of strings
+        @param args: Command and argument vector"""
+        return RunWithEnv.run(self, args, self.env_in_cwd)
+
 class Repository(RunWithEnv):
+    """Represents a git repository."""
     def __init__(self, directory):
         self.__git_dir = directory
         self.__refs = Refs(self)
-        self.__trees = ObjectCache(lambda sha1: Tree(sha1))
+        self.__blobs = ObjectCache(lambda sha1: Blob(self, sha1))
+        self.__trees = ObjectCache(lambda sha1: Tree(self, sha1))
         self.__commits = ObjectCache(lambda sha1: Commit(self, sha1))
         self.__default_index = None
         self.__default_worktree = None
@@ -224,16 +541,25 @@ class Repository(RunWithEnv):
         except run.RunException:
             raise RepositoryException('Cannot find git repository')
     @property
+    def current_branch_name(self):
+        """Return the name of the current branch."""
+        return utils.strip_prefix('refs/heads/', self.head_ref)
+    @property
     def default_index(self):
+        """An L{Index} object representing the default index file for the
+        repository."""
         if self.__default_index == None:
             self.__default_index = Index(
                 self, (os.environ.get('GIT_INDEX_FILE', None)
                        or os.path.join(self.__git_dir, 'index')))
         return self.__default_index
     def temp_index(self):
+        """Return an L{Index} object representing a new temporary index file
+        for the repository."""
         return Index(self, self.__git_dir)
     @property
     def default_worktree(self):
+        """A L{Worktree} object representing the default work tree."""
         if self.__default_worktree == None:
             path = os.environ.get('GIT_WORK_TREE', None)
             if not path:
@@ -245,6 +571,8 @@ class Repository(RunWithEnv):
         return self.__default_worktree
     @property
     def default_iw(self):
+        """An L{IndexAndWorktree} object representing the default index and
+        work tree for this repository."""
         if self.__default_iw == None:
             self.__default_iw = IndexAndWorktree(self.default_index,
                                                  self.default_worktree)
@@ -253,72 +581,120 @@ class Repository(RunWithEnv):
     refs = property(lambda self: self.__refs)
     def cat_object(self, sha1):
         return self.run(['git', 'cat-file', '-p', sha1]).raw_output()
-    def rev_parse(self, rev):
+    def rev_parse(self, rev, discard_stderr = False):
         try:
             return self.get_commit(self.run(
                     ['git', 'rev-parse', '%s^{commit}' % rev]
-                    ).output_one_line())
+                    ).discard_stderr(discard_stderr).output_one_line())
         except run.RunException:
             raise RepositoryException('%s: No such revision' % rev)
+    def get_blob(self, sha1):
+        return self.__blobs[sha1]
     def get_tree(self, sha1):
         return self.__trees[sha1]
     def get_commit(self, sha1):
         return self.__commits[sha1]
-    def commit(self, commitdata):
-        c = ['git', 'commit-tree', commitdata.tree.sha1]
-        for p in commitdata.parents:
-            c.append('-p')
-            c.append(p.sha1)
-        env = {}
-        for p, v1 in ((commitdata.author, 'AUTHOR'),
-                       (commitdata.committer, 'COMMITTER')):
-            if p != None:
-                for attr, v2 in (('name', 'NAME'), ('email', 'EMAIL'),
-                                 ('date', 'DATE')):
-                    if getattr(p, attr) != None:
-                        env['GIT_%s_%s' % (v1, v2)] = getattr(p, attr)
-        sha1 = self.run(c, env = env).raw_input(commitdata.message
-                                                ).output_one_line()
-        return self.get_commit(sha1)
+    def get_object(self, type, sha1):
+        return { Blob.typename: self.get_blob,
+                 Tree.typename: self.get_tree,
+                 Commit.typename: self.get_commit }[type](sha1)
+    def commit(self, objectdata):
+        return objectdata.commit(self)
     @property
-    def head(self):
+    def head_ref(self):
         try:
             return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
                             ).output_one_line()
         except run.RunException:
             raise DetachedHeadException()
-    def set_head(self, ref, msg):
+    def set_head_ref(self, ref, msg):
         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
+    def get_merge_bases(self, commit1, commit2):
+        """Return a set of merge bases of two commits."""
+        sha1_list = self.run(['git', 'merge-base', '--all',
+                              commit1.sha1, commit2.sha1]).output_lines()
+        return set(self.get_commit(sha1) for sha1 in sha1_list)
+    def describe(self, commit):
+        """Use git describe --all on the given commit."""
+        return self.run(['git', 'describe', '--all', commit.sha1]
+                       ).discard_stderr().discard_exitcode().raw_output()
     def simple_merge(self, base, ours, theirs):
-        """Given three trees, tries to do an in-index merge in a temporary
-        index with a temporary index. Returns the result tree, or None if
-        the merge failed (due to conflicts)."""
-        assert isinstance(base, Tree)
-        assert isinstance(ours, Tree)
-        assert isinstance(theirs, Tree)
-
-        # Take care of the really trivial cases.
-        if base == ours:
-            return theirs
-        if base == theirs:
-            return ours
-        if ours == theirs:
-            return ours
-
         index = self.temp_index()
         try:
-            index.merge(base, ours, theirs)
+            result, index_tree = index.merge(base, ours, theirs)
+        finally:
+            index.delete()
+        return result
+    def apply(self, tree, patch_text, quiet):
+        """Given a L{Tree} and a patch, will either return the new L{Tree}
+        that results when the patch is applied, or None if the patch
+        couldn't be applied."""
+        assert isinstance(tree, Tree)
+        if not patch_text:
+            return tree
+        index = self.temp_index()
+        try:
+            index.read_tree(tree)
             try:
+                index.apply(patch_text, quiet)
                 return index.write_tree()
             except MergeException:
                 return None
         finally:
             index.delete()
+    def diff_tree(self, t1, t2, diff_opts, binary = True):
+        """Given two L{Tree}s C{t1} and C{t2}, return the patch that takes
+        C{t1} to C{t2}.
+
+        @type diff_opts: list of strings
+        @param diff_opts: Extra diff options
+        @rtype: String
+        @return: Patch text"""
+        assert isinstance(t1, Tree)
+        assert isinstance(t2, Tree)
+        diff_opts = list(diff_opts)
+        if binary and not '--binary' in diff_opts:
+            diff_opts.append('--binary')
+        return self.run(['git', 'diff-tree', '-p'] + diff_opts
+                        + [t1.sha1, t2.sha1]).raw_output()
+    def diff_tree_files(self, t1, t2):
+        """Given two L{Tree}s C{t1} and C{t2}, iterate over all files for
+        which they differ. For each file, yield a tuple with the old
+        file mode, the new file mode, the old blob, the new blob, the
+        status, the old filename, and the new filename. Except in case
+        of a copy or a rename, the old and new filenames are
+        identical."""
+        assert isinstance(t1, Tree)
+        assert isinstance(t2, Tree)
+        i = iter(self.run(['git', 'diff-tree', '-r', '-z'] + [t1.sha1, t2.sha1]
+                          ).raw_output().split('\0'))
+        try:
+            while True:
+                x = i.next()
+                if not x:
+                    continue
+                omode, nmode, osha1, nsha1, status = x[1:].split(' ')
+                fn1 = i.next()
+                if status[0] in ['C', 'R']:
+                    fn2 = i.next()
+                else:
+                    fn2 = fn1
+                yield (omode, nmode, self.get_blob(osha1),
+                       self.get_blob(nsha1), status, fn1, fn2)
+        except StopIteration:
+            pass
 
 class MergeException(exception.StgException):
-    pass
+    """Exception raised when a merge fails for some reason."""
+
+class MergeConflictException(MergeException):
+    """Exception raised when a merge fails due to conflicts."""
+    def __init__(self, conflicts):
+        MergeException.__init__(self)
+        self.conflicts = conflicts
 
 class Index(RunWithEnv):
+    """Represents a git index file."""
     def __init__(self, repository, filename):
         self.__repository = repository
         if os.path.isdir(filename):
@@ -333,43 +709,121 @@ class Index(RunWithEnv):
     def read_tree(self, tree):
         self.run(['git', 'read-tree', tree.sha1]).no_output()
     def write_tree(self):
+        """Write the index contents to the repository.
+        @return: The resulting L{Tree}
+        @rtype: L{Tree}"""
         try:
             return self.__repository.get_tree(
                 self.run(['git', 'write-tree']).discard_stderr(
                     ).output_one_line())
         except run.RunException:
             raise MergeException('Conflicting merge')
-    def is_clean(self):
+    def is_clean(self, tree):
+        """Check whether the index is clean relative to the given treeish."""
         try:
-            self.run(['git', 'update-index', '--refresh']).discard_output()
+            self.run(['git', 'diff-index', '--quiet', '--cached', tree.sha1]
+                    ).discard_output()
         except run.RunException:
             return False
         else:
             return True
-    def merge(self, base, ours, theirs):
-        """In-index merge, no worktree involved."""
-        self.run(['git', 'read-tree', '-m', '-i', '--aggressive',
-                  base.sha1, ours.sha1, theirs.sha1]).no_output()
+    def apply(self, patch_text, quiet):
+        """In-index patch application, no worktree involved."""
+        try:
+            r = self.run(['git', 'apply', '--cached']).raw_input(patch_text)
+            if quiet:
+                r = r.discard_stderr()
+            r.no_output()
+        except run.RunException:
+            raise MergeException('Patch does not apply cleanly')
+    def apply_treediff(self, tree1, tree2, quiet):
+        """Apply the diff from C{tree1} to C{tree2} to the index."""
+        # Passing --full-index here is necessary to support binary
+        # files. It is also sufficient, since the repository already
+        # contains all involved objects; in other words, we don't have
+        # to use --binary.
+        self.apply(self.__repository.diff_tree(tree1, tree2, ['--full-index']),
+                   quiet)
+    def merge(self, base, ours, theirs, current = None):
+        """Use the index (and only the index) to do a 3-way merge of the
+        L{Tree}s C{base}, C{ours} and C{theirs}. The merge will either
+        succeed (in which case the first half of the return value is
+        the resulting tree) or fail cleanly (in which case the first
+        half of the return value is C{None}).
+
+        If C{current} is given (and not C{None}), it is assumed to be
+        the L{Tree} currently stored in the index; this information is
+        used to avoid having to read the right tree (either of C{ours}
+        and C{theirs}) into the index if it's already there. The
+        second half of the return value is the tree now stored in the
+        index, or C{None} if unknown. If the merge succeeded, this is
+        often the merge result."""
+        assert isinstance(base, Tree)
+        assert isinstance(ours, Tree)
+        assert isinstance(theirs, Tree)
+        assert current == None or isinstance(current, Tree)
+
+        # Take care of the really trivial cases.
+        if base == ours:
+            return (theirs, current)
+        if base == theirs:
+            return (ours, current)
+        if ours == theirs:
+            return (ours, current)
+
+        if current == theirs:
+            # Swap the trees. It doesn't matter since merging is
+            # symmetric, and will allow us to avoid the read_tree()
+            # call below.
+            ours, theirs = theirs, ours
+        if current != ours:
+            self.read_tree(ours)
+        try:
+            self.apply_treediff(base, theirs, quiet = True)
+            result = self.write_tree()
+            return (result, result)
+        except MergeException:
+            return (None, ours)
     def delete(self):
         if os.path.isfile(self.__filename):
             os.remove(self.__filename)
+    def conflicts(self):
+        """The set of conflicting paths."""
+        paths = set()
+        for line in self.run(['git', 'ls-files', '-z', '--unmerged']
+                             ).raw_output().split('\0')[:-1]:
+            stat, path = line.split('\t', 1)
+            paths.add(path)
+        return paths
 
 class Worktree(object):
+    """Represents a git worktree (that is, a checked-out file tree)."""
     def __init__(self, directory):
         self.__directory = directory
-    env = property(lambda self: { 'GIT_WORK_TREE': self.__directory })
+    env = property(lambda self: { 'GIT_WORK_TREE': '.' })
+    env_in_cwd = property(lambda self: { 'GIT_WORK_TREE': self.directory })
     directory = property(lambda self: self.__directory)
 
 class CheckoutException(exception.StgException):
-    pass
+    """Exception raised when a checkout fails."""
 
-class IndexAndWorktree(RunWithEnv):
+class IndexAndWorktree(RunWithEnvCwd):
+    """Represents a git index and a worktree. Anything that an index or
+    worktree can do on their own are handled by the L{Index} and
+    L{Worktree} classes; this class concerns itself with the
+    operations that require both."""
     def __init__(self, index, worktree):
         self.__index = index
         self.__worktree = worktree
     index = property(lambda self: self.__index)
     env = property(lambda self: utils.add_dict(self.__index.env,
                                                self.__worktree.env))
+    env_in_cwd = property(lambda self: self.__worktree.env_in_cwd)
+    cwd = property(lambda self: self.__worktree.directory)
+    def checkout_hard(self, tree):
+        assert isinstance(tree, Tree)
+        self.run(['git', 'read-tree', '--reset', '-u', tree.sha1]
+                 ).discard_output()
     def checkout(self, old_tree, new_tree):
         # TODO: Optionally do a 3-way instead of doing nothing when we
         # have a problem. Or maybe we should stash changes in a patch?
@@ -379,24 +833,119 @@ class IndexAndWorktree(RunWithEnv):
             self.run(['git', 'read-tree', '-u', '-m',
                       '--exclude-per-directory=.gitignore',
                       old_tree.sha1, new_tree.sha1]
-                     ).cwd(self.__worktree.directory).discard_output()
+                     ).discard_output()
         except run.RunException:
             raise CheckoutException('Index/workdir dirty')
-    def merge(self, base, ours, theirs):
+    def merge(self, base, ours, theirs, interactive = False):
         assert isinstance(base, Tree)
         assert isinstance(ours, Tree)
         assert isinstance(theirs, Tree)
         try:
-            self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
-                      theirs.sha1],
-                     env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
-                             'GITHEAD_%s' % ours.sha1: 'current',
-                             'GITHEAD_%s' % theirs.sha1: 'patched'}
-                     ).cwd(self.__worktree.directory).discard_output()
+            r = self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
+                          theirs.sha1],
+                         env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
+                                 'GITHEAD_%s' % ours.sha1: 'current',
+                                 'GITHEAD_%s' % theirs.sha1: 'patched'})
+            r.returns([0, 1])
+            output = r.output_lines()
+            if r.exitcode:
+                # There were conflicts
+                if interactive:
+                    self.mergetool()
+                else:
+                    conflicts = [l for l in output if l.startswith('CONFLICT')]
+                    raise MergeConflictException(conflicts)
         except run.RunException, e:
             raise MergeException('Index/worktree dirty')
-    def changed_files(self):
-        return self.run(['git', 'diff-files', '--name-only']).output_lines()
-    def update_index(self, files):
-        self.run(['git', 'update-index', '--remove', '-z', '--stdin']
-                 ).input_nulterm(files).discard_output()
+    def mergetool(self, files = ()):
+        """Invoke 'git mergetool' on the current IndexAndWorktree to resolve
+        any outstanding conflicts. If 'not files', all the files in an
+        unmerged state will be processed."""
+        self.run(['git', 'mergetool'] + list(files)).returns([0, 1]).run()
+        # check for unmerged entries (prepend 'CONFLICT ' for consistency with
+        # merge())
+        conflicts = ['CONFLICT ' + f for f in self.index.conflicts()]
+        if conflicts:
+            raise MergeConflictException(conflicts)
+    def changed_files(self, tree, pathlimits = []):
+        """Return the set of files in the worktree that have changed with
+        respect to C{tree}. The listing is optionally restricted to
+        those files that match any of the path limiters given.
+
+        The path limiters are relative to the current working
+        directory; the returned file names are relative to the
+        repository root."""
+        assert isinstance(tree, Tree)
+        return set(self.run_in_cwd(
+                ['git', 'diff-index', tree.sha1, '--name-only', '-z', '--']
+                + list(pathlimits)).raw_output().split('\0')[:-1])
+    def update_index(self, paths):
+        """Update the index with files from the worktree. C{paths} is an
+        iterable of paths relative to the root of the repository."""
+        cmd = ['git', 'update-index', '--remove']
+        self.run(cmd + ['-z', '--stdin']
+                 ).input_nulterm(paths).discard_output()
+    def worktree_clean(self):
+        """Check whether the worktree is clean relative to index."""
+        try:
+            self.run(['git', 'update-index', '--refresh']).discard_output()
+        except run.RunException:
+            return False
+        else:
+            return True
+
+class Branch(object):
+    """Represents a Git branch."""
+    def __init__(self, repository, name):
+        self.__repository = repository
+        self.__name = name
+        try:
+            self.head
+        except KeyError:
+            raise BranchException('%s: no such branch' % name)
+
+    name = property(lambda self: self.__name)
+    repository = property(lambda self: self.__repository)
+
+    def __ref(self):
+        return 'refs/heads/%s' % self.__name
+    @property
+    def head(self):
+        return self.__repository.refs.get(self.__ref())
+    def set_head(self, commit, msg):
+        self.__repository.refs.set(self.__ref(), commit, msg)
+
+    def set_parent_remote(self, name):
+        value = config.set('branch.%s.remote' % self.__name, name)
+    def set_parent_branch(self, name):
+        if config.get('branch.%s.remote' % self.__name):
+            # Never set merge if remote is not set to avoid
+            # possibly-erroneous lookups into 'origin'
+            config.set('branch.%s.merge' % self.__name, name)
+
+    @classmethod
+    def create(cls, repository, name, create_at = None):
+        """Create a new Git branch and return the corresponding
+        L{Branch} object."""
+        try:
+            branch = cls(repository, name)
+        except BranchException:
+            branch = None
+        if branch:
+            raise BranchException('%s: branch already exists' % name)
+
+        cmd = ['git', 'branch']
+        if create_at:
+            cmd.append(create_at.sha1)
+        repository.run(['git', 'branch', create_at.sha1]).discard_output()
+
+        return cls(repository, name)
+
+def diffstat(diff):
+    """Return the diffstat of the supplied diff."""
+    return run.Run('git', 'apply', '--stat', '--summary'
+                   ).raw_input(diff).raw_output()
+
+def clone(remote, local):
+    """Clone a remote repository using 'git clone'."""
+    run.Run('git', 'clone', remote, local).run()