chiark / gitweb /
2386e279f6defd042bb90968ff3087303eac7d5c
[stgit] / stgit / lib / git.py
1 """A Python class hierarchy wrapping a git repository and its
2 contents."""
3
4 import os, os.path, re
5 from datetime import datetime, timedelta, tzinfo
6
7 from stgit import exception, run, utils
8 from stgit.config import config
9
10 class Immutable(object):
11     """I{Immutable} objects cannot be modified once created. Any
12     modification methods will return a new object, leaving the
13     original object as it was.
14
15     The reason for this is that we want to be able to represent git
16     objects, which are immutable, and want to be able to create new
17     git objects that are just slight modifications of other git
18     objects. (Such as, for example, modifying the commit message of a
19     commit object while leaving the rest of it intact. This involves
20     creating a whole new commit object that's exactly like the old one
21     except for the commit message.)
22
23     The L{Immutable} class doesn't actually enforce immutability --
24     that is up to the individual immutable subclasses. It just serves
25     as documentation."""
26
27 class RepositoryException(exception.StgException):
28     """Base class for all exceptions due to failed L{Repository}
29     operations."""
30
31 class BranchException(exception.StgException):
32     """Exception raised by failed L{Branch} operations."""
33
34 class DateException(exception.StgException):
35     """Exception raised when a date+time string could not be parsed."""
36     def __init__(self, string, type):
37         exception.StgException.__init__(
38             self, '"%s" is not a valid %s' % (string, type))
39
40 class DetachedHeadException(RepositoryException):
41     """Exception raised when HEAD is detached (that is, there is no
42     current branch)."""
43     def __init__(self):
44         RepositoryException.__init__(self, 'Not on any branch')
45
46 class Repr(object):
47     """Utility class that defines C{__reps__} in terms of C{__str__}."""
48     def __repr__(self):
49         return str(self)
50
51 class NoValue(object):
52     """A handy default value that is guaranteed to be distinct from any
53     real argument value."""
54     pass
55
56 def make_defaults(defaults):
57     def d(val, attr, default_fun = lambda: None):
58         if val != NoValue:
59             return val
60         elif defaults != NoValue:
61             return getattr(defaults, attr)
62         else:
63             return default_fun()
64     return d
65
66 class TimeZone(tzinfo, Repr):
67     """A simple time zone class for static offsets from UTC. (We have to
68     define our own since Python's standard library doesn't define any
69     time zone classes.)"""
70     def __init__(self, tzstring):
71         m = re.match(r'^([+-])(\d{2}):?(\d{2})$', tzstring)
72         if not m:
73             raise DateException(tzstring, 'time zone')
74         sign = int(m.group(1) + '1')
75         try:
76             self.__offset = timedelta(hours = sign*int(m.group(2)),
77                                       minutes = sign*int(m.group(3)))
78         except OverflowError:
79             raise DateException(tzstring, 'time zone')
80         self.__name = tzstring
81     def utcoffset(self, dt):
82         return self.__offset
83     def tzname(self, dt):
84         return self.__name
85     def dst(self, dt):
86         return timedelta(0)
87     def __str__(self):
88         return self.__name
89
90 class Date(Immutable, Repr):
91     """Represents a timestamp used in git commits."""
92     def __init__(self, datestring):
93         # Try git-formatted date.
94         m = re.match(r'^(\d+)\s+([+-]\d\d:?\d\d)$', datestring)
95         if m:
96             try:
97                 self.__time = datetime.fromtimestamp(int(m.group(1)),
98                                                      TimeZone(m.group(2)))
99             except ValueError:
100                 raise DateException(datestring, 'date')
101             return
102
103         # Try iso-formatted date.
104         m = re.match(r'^(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})\s+'
105                      + r'([+-]\d\d:?\d\d)$', datestring)
106         if m:
107             try:
108                 self.__time = datetime(
109                     *[int(m.group(i + 1)) for i in xrange(6)],
110                     **{'tzinfo': TimeZone(m.group(7))})
111             except ValueError:
112                 raise DateException(datestring, 'date')
113             return
114
115         raise DateException(datestring, 'date')
116     def __str__(self):
117         return self.isoformat()
118     def isoformat(self):
119         """Human-friendly ISO 8601 format."""
120         return '%s %s' % (self.__time.replace(tzinfo = None).isoformat(' '),
121                           self.__time.tzinfo)
122     @classmethod
123     def maybe(cls, datestring):
124         """Return a new object initialized with the argument if it contains a
125         value (otherwise, just return the argument)."""
126         if datestring in [None, NoValue]:
127             return datestring
128         return cls(datestring)
129
130 class Person(Immutable, Repr):
131     """Represents an author or committer in a git commit object. Contains
132     name, email and timestamp."""
133     def __init__(self, name = NoValue, email = NoValue,
134                  date = NoValue, defaults = NoValue):
135         d = make_defaults(defaults)
136         self.__name = d(name, 'name')
137         self.__email = d(email, 'email')
138         self.__date = d(date, 'date')
139         assert isinstance(self.__date, Date) or self.__date in [None, NoValue]
140     name = property(lambda self: self.__name)
141     email = property(lambda self: self.__email)
142     date = property(lambda self: self.__date)
143     def set_name(self, name):
144         return type(self)(name = name, defaults = self)
145     def set_email(self, email):
146         return type(self)(email = email, defaults = self)
147     def set_date(self, date):
148         return type(self)(date = date, defaults = self)
149     def __str__(self):
150         return '%s <%s> %s' % (self.name, self.email, self.date)
151     @classmethod
152     def parse(cls, s):
153         m = re.match(r'^([^<]*)<([^>]*)>\s+(\d+\s+[+-]\d{4})$', s)
154         assert m
155         name = m.group(1).strip()
156         email = m.group(2)
157         date = Date(m.group(3))
158         return cls(name, email, date)
159     @classmethod
160     def user(cls):
161         if not hasattr(cls, '__user'):
162             cls.__user = cls(name = config.get('user.name'),
163                              email = config.get('user.email'))
164         return cls.__user
165     @classmethod
166     def author(cls):
167         if not hasattr(cls, '__author'):
168             cls.__author = cls(
169                 name = os.environ.get('GIT_AUTHOR_NAME', NoValue),
170                 email = os.environ.get('GIT_AUTHOR_EMAIL', NoValue),
171                 date = Date.maybe(os.environ.get('GIT_AUTHOR_DATE', NoValue)),
172                 defaults = cls.user())
173         return cls.__author
174     @classmethod
175     def committer(cls):
176         if not hasattr(cls, '__committer'):
177             cls.__committer = cls(
178                 name = os.environ.get('GIT_COMMITTER_NAME', NoValue),
179                 email = os.environ.get('GIT_COMMITTER_EMAIL', NoValue),
180                 date = Date.maybe(
181                     os.environ.get('GIT_COMMITTER_DATE', NoValue)),
182                 defaults = cls.user())
183         return cls.__committer
184
185 class GitObject(Immutable, Repr):
186     """Base class for all git objects. One git object is represented by at
187     most one C{GitObject}, which makes it possible to compare them
188     using normal Python object comparison; it also ensures we don't
189     waste more memory than necessary."""
190
191 class BlobData(Immutable, Repr):
192     """Represents the data contents of a git blob object."""
193     def __init__(self, string):
194         self.__string = str(string)
195     str = property(lambda self: self.__string)
196     def commit(self, repository):
197         """Commit the blob.
198         @return: The committed blob
199         @rtype: L{Blob}"""
200         sha1 = repository.run(['git', 'hash-object', '-w', '--stdin']
201                               ).raw_input(self.str).output_one_line()
202         return repository.get_blob(sha1)
203
204 class Blob(GitObject):
205     """Represents a git blob object. All the actual data contents of the
206     blob object is stored in the L{data} member, which is a
207     L{BlobData} object."""
208     typename = 'blob'
209     default_perm = '100644'
210     def __init__(self, repository, sha1):
211         self.__repository = repository
212         self.__sha1 = sha1
213     sha1 = property(lambda self: self.__sha1)
214     def __str__(self):
215         return 'Blob<%s>' % self.sha1
216     @property
217     def data(self):
218         return BlobData(self.__repository.cat_object(self.sha1))
219
220 class ImmutableDict(dict):
221     """A dictionary that cannot be modified once it's been created."""
222     def error(*args, **kwargs):
223         raise TypeError('Cannot modify immutable dict')
224     __delitem__ = error
225     __setitem__ = error
226     clear = error
227     pop = error
228     popitem = error
229     setdefault = error
230     update = error
231
232 class TreeData(Immutable, Repr):
233     """Represents the data contents of a git tree object."""
234     @staticmethod
235     def __x(po):
236         if isinstance(po, GitObject):
237             perm, object = po.default_perm, po
238         else:
239             perm, object = po
240         return perm, object
241     def __init__(self, entries):
242         """Create a new L{TreeData} object from the given mapping from names
243         (strings) to either (I{permission}, I{object}) tuples or just
244         objects."""
245         self.__entries = ImmutableDict((name, self.__x(po))
246                                        for (name, po) in entries.iteritems())
247     entries = property(lambda self: self.__entries)
248     """Map from name to (I{permission}, I{object}) tuple."""
249     def set_entry(self, name, po):
250         """Create a new L{TreeData} object identical to this one, except that
251         it maps C{name} to C{po}.
252
253         @param name: Name of the changed mapping
254         @type name: C{str}
255         @param po: Value of the changed mapping
256         @type po: L{Blob} or L{Tree} or (C{str}, L{Blob} or L{Tree})
257         @return: The new L{TreeData} object
258         @rtype: L{TreeData}"""
259         e = dict(self.entries)
260         e[name] = self.__x(po)
261         return type(self)(e)
262     def del_entry(self, name):
263         """Create a new L{TreeData} object identical to this one, except that
264         it doesn't map C{name} to anything.
265
266         @param name: Name of the deleted mapping
267         @type name: C{str}
268         @return: The new L{TreeData} object
269         @rtype: L{TreeData}"""
270         e = dict(self.entries)
271         del e[name]
272         return type(self)(e)
273     def commit(self, repository):
274         """Commit the tree.
275         @return: The committed tree
276         @rtype: L{Tree}"""
277         listing = ''.join(
278             '%s %s %s\t%s\0' % (mode, obj.typename, obj.sha1, name)
279             for (name, (mode, obj)) in self.entries.iteritems())
280         sha1 = repository.run(['git', 'mktree', '-z']
281                               ).raw_input(listing).output_one_line()
282         return repository.get_tree(sha1)
283     @classmethod
284     def parse(cls, repository, s):
285         """Parse a raw git tree description.
286
287         @return: A new L{TreeData} object
288         @rtype: L{TreeData}"""
289         entries = {}
290         for line in s.split('\0')[:-1]:
291             m = re.match(r'^([0-7]{6}) ([a-z]+) ([0-9a-f]{40})\t(.*)$', line)
292             assert m
293             perm, type, sha1, name = m.groups()
294             entries[name] = (perm, repository.get_object(type, sha1))
295         return cls(entries)
296
297 class Tree(GitObject):
298     """Represents a git tree object. All the actual data contents of the
299     tree object is stored in the L{data} member, which is a
300     L{TreeData} object."""
301     typename = 'tree'
302     default_perm = '040000'
303     def __init__(self, repository, sha1):
304         self.__sha1 = sha1
305         self.__repository = repository
306         self.__data = None
307     sha1 = property(lambda self: self.__sha1)
308     @property
309     def data(self):
310         if self.__data == None:
311             self.__data = TreeData.parse(
312                 self.__repository,
313                 self.__repository.run(['git', 'ls-tree', '-z', self.sha1]
314                                       ).raw_output())
315         return self.__data
316     def __str__(self):
317         return 'Tree<sha1: %s>' % self.sha1
318
319 class CommitData(Immutable, Repr):
320     """Represents the data contents of a git commit object."""
321     def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
322                  committer = NoValue, message = NoValue, defaults = NoValue):
323         d = make_defaults(defaults)
324         self.__tree = d(tree, 'tree')
325         self.__parents = d(parents, 'parents')
326         self.__author = d(author, 'author', Person.author)
327         self.__committer = d(committer, 'committer', Person.committer)
328         self.__message = d(message, 'message')
329     tree = property(lambda self: self.__tree)
330     parents = property(lambda self: self.__parents)
331     @property
332     def parent(self):
333         assert len(self.__parents) == 1
334         return self.__parents[0]
335     author = property(lambda self: self.__author)
336     committer = property(lambda self: self.__committer)
337     message = property(lambda self: self.__message)
338     def set_tree(self, tree):
339         return type(self)(tree = tree, defaults = self)
340     def set_parents(self, parents):
341         return type(self)(parents = parents, defaults = self)
342     def add_parent(self, parent):
343         return type(self)(parents = list(self.parents or []) + [parent],
344                           defaults = self)
345     def set_parent(self, parent):
346         return self.set_parents([parent])
347     def set_author(self, author):
348         return type(self)(author = author, defaults = self)
349     def set_committer(self, committer):
350         return type(self)(committer = committer, defaults = self)
351     def set_message(self, message):
352         return type(self)(message = message, defaults = self)
353     def is_nochange(self):
354         return len(self.parents) == 1 and self.tree == self.parent.data.tree
355     def __str__(self):
356         if self.tree == None:
357             tree = None
358         else:
359             tree = self.tree.sha1
360         if self.parents == None:
361             parents = None
362         else:
363             parents = [p.sha1 for p in self.parents]
364         return ('CommitData<tree: %s, parents: %s, author: %s,'
365                 ' committer: %s, message: "%s">'
366                 ) % (tree, parents, self.author, self.committer, self.message)
367     def commit(self, repository):
368         """Commit the commit.
369         @return: The committed commit
370         @rtype: L{Commit}"""
371         c = ['git', 'commit-tree', self.tree.sha1]
372         for p in self.parents:
373             c.append('-p')
374             c.append(p.sha1)
375         env = {}
376         for p, v1 in ((self.author, 'AUTHOR'),
377                        (self.committer, 'COMMITTER')):
378             if p != None:
379                 for attr, v2 in (('name', 'NAME'), ('email', 'EMAIL'),
380                                  ('date', 'DATE')):
381                     if getattr(p, attr) != None:
382                         env['GIT_%s_%s' % (v1, v2)] = str(getattr(p, attr))
383         sha1 = repository.run(c, env = env).raw_input(self.message
384                                                       ).output_one_line()
385         return repository.get_commit(sha1)
386     @classmethod
387     def parse(cls, repository, s):
388         """Parse a raw git commit description.
389         @return: A new L{CommitData} object
390         @rtype: L{CommitData}"""
391         cd = cls(parents = [])
392         lines = list(s.splitlines(True))
393         for i in xrange(len(lines)):
394             line = lines[i].strip()
395             if not line:
396                 return cd.set_message(''.join(lines[i+1:]))
397             key, value = line.split(None, 1)
398             if key == 'tree':
399                 cd = cd.set_tree(repository.get_tree(value))
400             elif key == 'parent':
401                 cd = cd.add_parent(repository.get_commit(value))
402             elif key == 'author':
403                 cd = cd.set_author(Person.parse(value))
404             elif key == 'committer':
405                 cd = cd.set_committer(Person.parse(value))
406             else:
407                 assert False
408         assert False
409
410 class Commit(GitObject):
411     """Represents a git commit object. All the actual data contents of the
412     commit object is stored in the L{data} member, which is a
413     L{CommitData} object."""
414     typename = 'commit'
415     def __init__(self, repository, sha1):
416         self.__sha1 = sha1
417         self.__repository = repository
418         self.__data = None
419     sha1 = property(lambda self: self.__sha1)
420     @property
421     def data(self):
422         if self.__data == None:
423             self.__data = CommitData.parse(
424                 self.__repository,
425                 self.__repository.cat_object(self.sha1))
426         return self.__data
427     def __str__(self):
428         return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
429
430 class Refs(object):
431     """Accessor for the refs stored in a git repository. Will
432     transparently cache the values of all refs."""
433     def __init__(self, repository):
434         self.__repository = repository
435         self.__refs = None
436     def __cache_refs(self):
437         """(Re-)Build the cache of all refs in the repository."""
438         self.__refs = {}
439         for line in self.__repository.run(['git', 'show-ref']).output_lines():
440             m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
441             sha1, ref = m.groups()
442             self.__refs[ref] = sha1
443     def get(self, ref):
444         """Get the Commit the given ref points to. Throws KeyError if ref
445         doesn't exist."""
446         if self.__refs == None:
447             self.__cache_refs()
448         return self.__repository.get_commit(self.__refs[ref])
449     def exists(self, ref):
450         """Check if the given ref exists."""
451         try:
452             self.get(ref)
453         except KeyError:
454             return False
455         else:
456             return True
457     def set(self, ref, commit, msg):
458         """Write the sha1 of the given Commit to the ref. The ref may or may
459         not already exist."""
460         if self.__refs == None:
461             self.__cache_refs()
462         old_sha1 = self.__refs.get(ref, '0'*40)
463         new_sha1 = commit.sha1
464         if old_sha1 != new_sha1:
465             self.__repository.run(['git', 'update-ref', '-m', msg,
466                                    ref, new_sha1, old_sha1]).no_output()
467             self.__refs[ref] = new_sha1
468     def delete(self, ref):
469         """Delete the given ref. Throws KeyError if ref doesn't exist."""
470         if self.__refs == None:
471             self.__cache_refs()
472         self.__repository.run(['git', 'update-ref',
473                                '-d', ref, self.__refs[ref]]).no_output()
474         del self.__refs[ref]
475
476 class ObjectCache(object):
477     """Cache for Python objects, for making sure that we create only one
478     Python object per git object. This reduces memory consumption and
479     makes object comparison very cheap."""
480     def __init__(self, create):
481         self.__objects = {}
482         self.__create = create
483     def __getitem__(self, name):
484         if not name in self.__objects:
485             self.__objects[name] = self.__create(name)
486         return self.__objects[name]
487     def __contains__(self, name):
488         return name in self.__objects
489     def __setitem__(self, name, val):
490         assert not name in self.__objects
491         self.__objects[name] = val
492
493 class RunWithEnv(object):
494     def run(self, args, env = {}):
495         """Run the given command with an environment given by self.env.
496
497         @type args: list of strings
498         @param args: Command and argument vector
499         @type env: dict
500         @param env: Extra environment"""
501         return run.Run(*args).env(utils.add_dict(self.env, env))
502
503 class RunWithEnvCwd(RunWithEnv):
504     def run(self, args, env = {}):
505         """Run the given command with an environment given by self.env, and
506         current working directory given by self.cwd.
507
508         @type args: list of strings
509         @param args: Command and argument vector
510         @type env: dict
511         @param env: Extra environment"""
512         return RunWithEnv.run(self, args, env).cwd(self.cwd)
513
514 class Repository(RunWithEnv):
515     """Represents a git repository."""
516     def __init__(self, directory):
517         self.__git_dir = directory
518         self.__refs = Refs(self)
519         self.__blobs = ObjectCache(lambda sha1: Blob(self, sha1))
520         self.__trees = ObjectCache(lambda sha1: Tree(self, sha1))
521         self.__commits = ObjectCache(lambda sha1: Commit(self, sha1))
522         self.__default_index = None
523         self.__default_worktree = None
524         self.__default_iw = None
525     env = property(lambda self: { 'GIT_DIR': self.__git_dir })
526     @classmethod
527     def default(cls):
528         """Return the default repository."""
529         try:
530             return cls(run.Run('git', 'rev-parse', '--git-dir'
531                                ).output_one_line())
532         except run.RunException:
533             raise RepositoryException('Cannot find git repository')
534     @property
535     def current_branch_name(self):
536         """Return the name of the current branch."""
537         return utils.strip_prefix('refs/heads/', self.head_ref)
538     @property
539     def default_index(self):
540         """An L{Index} object representing the default index file for the
541         repository."""
542         if self.__default_index == None:
543             self.__default_index = Index(
544                 self, (os.environ.get('GIT_INDEX_FILE', None)
545                        or os.path.join(self.__git_dir, 'index')))
546         return self.__default_index
547     def temp_index(self):
548         """Return an L{Index} object representing a new temporary index file
549         for the repository."""
550         return Index(self, self.__git_dir)
551     @property
552     def default_worktree(self):
553         """A L{Worktree} object representing the default work tree."""
554         if self.__default_worktree == None:
555             path = os.environ.get('GIT_WORK_TREE', None)
556             if not path:
557                 o = run.Run('git', 'rev-parse', '--show-cdup').output_lines()
558                 o = o or ['.']
559                 assert len(o) == 1
560                 path = o[0]
561             self.__default_worktree = Worktree(path)
562         return self.__default_worktree
563     @property
564     def default_iw(self):
565         """An L{IndexAndWorktree} object representing the default index and
566         work tree for this repository."""
567         if self.__default_iw == None:
568             self.__default_iw = IndexAndWorktree(self.default_index,
569                                                  self.default_worktree)
570         return self.__default_iw
571     directory = property(lambda self: self.__git_dir)
572     refs = property(lambda self: self.__refs)
573     def cat_object(self, sha1):
574         return self.run(['git', 'cat-file', '-p', sha1]).raw_output()
575     def rev_parse(self, rev, discard_stderr = False):
576         try:
577             return self.get_commit(self.run(
578                     ['git', 'rev-parse', '%s^{commit}' % rev]
579                     ).discard_stderr(discard_stderr).output_one_line())
580         except run.RunException:
581             raise RepositoryException('%s: No such revision' % rev)
582     def get_blob(self, sha1):
583         return self.__blobs[sha1]
584     def get_tree(self, sha1):
585         return self.__trees[sha1]
586     def get_commit(self, sha1):
587         return self.__commits[sha1]
588     def get_object(self, type, sha1):
589         return { Blob.typename: self.get_blob,
590                  Tree.typename: self.get_tree,
591                  Commit.typename: self.get_commit }[type](sha1)
592     def commit(self, objectdata):
593         return objectdata.commit(self)
594     @property
595     def head_ref(self):
596         try:
597             return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
598                             ).output_one_line()
599         except run.RunException:
600             raise DetachedHeadException()
601     def set_head_ref(self, ref, msg):
602         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
603     def simple_merge(self, base, ours, theirs):
604         index = self.temp_index()
605         try:
606             result, index_tree = index.merge(base, ours, theirs)
607         finally:
608             index.delete()
609         return result
610     def apply(self, tree, patch_text, quiet):
611         """Given a L{Tree} and a patch, will either return the new L{Tree}
612         that results when the patch is applied, or None if the patch
613         couldn't be applied."""
614         assert isinstance(tree, Tree)
615         if not patch_text:
616             return tree
617         index = self.temp_index()
618         try:
619             index.read_tree(tree)
620             try:
621                 index.apply(patch_text, quiet)
622                 return index.write_tree()
623             except MergeException:
624                 return None
625         finally:
626             index.delete()
627     def diff_tree(self, t1, t2, diff_opts):
628         """Given two L{Tree}s C{t1} and C{t2}, return the patch that takes
629         C{t1} to C{t2}.
630
631         @type diff_opts: list of strings
632         @param diff_opts: Extra diff options
633         @rtype: String
634         @return: Patch text"""
635         assert isinstance(t1, Tree)
636         assert isinstance(t2, Tree)
637         return self.run(['git', 'diff-tree', '-p'] + list(diff_opts)
638                         + [t1.sha1, t2.sha1]).raw_output()
639
640 class MergeException(exception.StgException):
641     """Exception raised when a merge fails for some reason."""
642
643 class MergeConflictException(MergeException):
644     """Exception raised when a merge fails due to conflicts."""
645
646 class Index(RunWithEnv):
647     """Represents a git index file."""
648     def __init__(self, repository, filename):
649         self.__repository = repository
650         if os.path.isdir(filename):
651             # Create a temp index in the given directory.
652             self.__filename = os.path.join(
653                 filename, 'index.temp-%d-%x' % (os.getpid(), id(self)))
654             self.delete()
655         else:
656             self.__filename = filename
657     env = property(lambda self: utils.add_dict(
658             self.__repository.env, { 'GIT_INDEX_FILE': self.__filename }))
659     def read_tree(self, tree):
660         self.run(['git', 'read-tree', tree.sha1]).no_output()
661     def write_tree(self):
662         try:
663             return self.__repository.get_tree(
664                 self.run(['git', 'write-tree']).discard_stderr(
665                     ).output_one_line())
666         except run.RunException:
667             raise MergeException('Conflicting merge')
668     def is_clean(self):
669         try:
670             self.run(['git', 'update-index', '--refresh']).discard_output()
671         except run.RunException:
672             return False
673         else:
674             return True
675     def apply(self, patch_text, quiet):
676         """In-index patch application, no worktree involved."""
677         try:
678             r = self.run(['git', 'apply', '--cached']).raw_input(patch_text)
679             if quiet:
680                 r = r.discard_stderr()
681             r.no_output()
682         except run.RunException:
683             raise MergeException('Patch does not apply cleanly')
684     def apply_treediff(self, tree1, tree2, quiet):
685         """Apply the diff from C{tree1} to C{tree2} to the index."""
686         # Passing --full-index here is necessary to support binary
687         # files. It is also sufficient, since the repository already
688         # contains all involved objects; in other words, we don't have
689         # to use --binary.
690         self.apply(self.__repository.diff_tree(tree1, tree2, ['--full-index']),
691                    quiet)
692     def merge(self, base, ours, theirs, current = None):
693         """Use the index (and only the index) to do a 3-way merge of the
694         L{Tree}s C{base}, C{ours} and C{theirs}. The merge will either
695         succeed (in which case the first half of the return value is
696         the resulting tree) or fail cleanly (in which case the first
697         half of the return value is C{None}).
698
699         If C{current} is given (and not C{None}), it is assumed to be
700         the L{Tree} currently stored in the index; this information is
701         used to avoid having to read the right tree (either of C{ours}
702         and C{theirs}) into the index if it's already there. The
703         second half of the return value is the tree now stored in the
704         index, or C{None} if unknown. If the merge succeeded, this is
705         often the merge result."""
706         assert isinstance(base, Tree)
707         assert isinstance(ours, Tree)
708         assert isinstance(theirs, Tree)
709         assert current == None or isinstance(current, Tree)
710
711         # Take care of the really trivial cases.
712         if base == ours:
713             return (theirs, current)
714         if base == theirs:
715             return (ours, current)
716         if ours == theirs:
717             return (ours, current)
718
719         if current == theirs:
720             # Swap the trees. It doesn't matter since merging is
721             # symmetric, and will allow us to avoid the read_tree()
722             # call below.
723             ours, theirs = theirs, ours
724         if current != ours:
725             self.read_tree(ours)
726         try:
727             self.apply_treediff(base, theirs, quiet = True)
728             result = self.write_tree()
729             return (result, result)
730         except MergeException:
731             return (None, ours)
732     def delete(self):
733         if os.path.isfile(self.__filename):
734             os.remove(self.__filename)
735     def conflicts(self):
736         """The set of conflicting paths."""
737         paths = set()
738         for line in self.run(['git', 'ls-files', '-z', '--unmerged']
739                              ).raw_output().split('\0')[:-1]:
740             stat, path = line.split('\t', 1)
741             paths.add(path)
742         return paths
743
744 class Worktree(object):
745     """Represents a git worktree (that is, a checked-out file tree)."""
746     def __init__(self, directory):
747         self.__directory = directory
748     env = property(lambda self: { 'GIT_WORK_TREE': '.' })
749     directory = property(lambda self: self.__directory)
750
751 class CheckoutException(exception.StgException):
752     """Exception raised when a checkout fails."""
753
754 class IndexAndWorktree(RunWithEnvCwd):
755     """Represents a git index and a worktree. Anything that an index or
756     worktree can do on their own are handled by the L{Index} and
757     L{Worktree} classes; this class concerns itself with the
758     operations that require both."""
759     def __init__(self, index, worktree):
760         self.__index = index
761         self.__worktree = worktree
762     index = property(lambda self: self.__index)
763     env = property(lambda self: utils.add_dict(self.__index.env,
764                                                self.__worktree.env))
765     cwd = property(lambda self: self.__worktree.directory)
766     def checkout(self, old_tree, new_tree):
767         # TODO: Optionally do a 3-way instead of doing nothing when we
768         # have a problem. Or maybe we should stash changes in a patch?
769         assert isinstance(old_tree, Tree)
770         assert isinstance(new_tree, Tree)
771         try:
772             self.run(['git', 'read-tree', '-u', '-m',
773                       '--exclude-per-directory=.gitignore',
774                       old_tree.sha1, new_tree.sha1]
775                      ).discard_output()
776         except run.RunException:
777             raise CheckoutException('Index/workdir dirty')
778     def merge(self, base, ours, theirs):
779         assert isinstance(base, Tree)
780         assert isinstance(ours, Tree)
781         assert isinstance(theirs, Tree)
782         try:
783             r = self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
784                           theirs.sha1],
785                          env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
786                                  'GITHEAD_%s' % ours.sha1: 'current',
787                                  'GITHEAD_%s' % theirs.sha1: 'patched'})
788             r.discard_output()
789         except run.RunException, e:
790             if r.exitcode == 1:
791                 raise MergeConflictException()
792             else:
793                 raise MergeException('Index/worktree dirty')
794     def changed_files(self):
795         return self.run(['git', 'diff-files', '--name-only']).output_lines()
796     def update_index(self, files):
797         self.run(['git', 'update-index', '--remove', '-z', '--stdin']
798                  ).input_nulterm(files).discard_output()
799
800 class Branch(object):
801     """Represents a Git branch."""
802     def __init__(self, repository, name):
803         self.__repository = repository
804         self.__name = name
805         try:
806             self.head
807         except KeyError:
808             raise BranchException('%s: no such branch' % name)
809
810     name = property(lambda self: self.__name)
811     repository = property(lambda self: self.__repository)
812
813     def __ref(self):
814         return 'refs/heads/%s' % self.__name
815     @property
816     def head(self):
817         return self.__repository.refs.get(self.__ref())
818     def set_head(self, commit, msg):
819         self.__repository.refs.set(self.__ref(), commit, msg)
820
821     def set_parent_remote(self, name):
822         value = config.set('branch.%s.remote' % self.__name, name)
823     def set_parent_branch(self, name):
824         if config.get('branch.%s.remote' % self.__name):
825             # Never set merge if remote is not set to avoid
826             # possibly-erroneous lookups into 'origin'
827             config.set('branch.%s.merge' % self.__name, name)
828
829     @classmethod
830     def create(cls, repository, name, create_at = None):
831         """Create a new Git branch and return the corresponding
832         L{Branch} object."""
833         try:
834             branch = cls(repository, name)
835         except BranchException:
836             branch = None
837         if branch:
838             raise BranchException('%s: branch already exists' % name)
839
840         cmd = ['git', 'branch']
841         if create_at:
842             cmd.append(create_at.sha1)
843         repository.run(['git', 'branch', create_at.sha1]).discard_output()
844
845         return cls(repository, name)