chiark / gitweb /
fd66f6d18e028dda3eeb880fa2c77952956ae63e
[stgit] / stgit / lib / git.py
1 """A Python class hierarchy wrapping a git repository and its
2 contents."""
3
4 import os, os.path, re
5 from datetime import datetime, timedelta, tzinfo
6
7 from stgit import exception, run, utils
8 from stgit.config import config
9
10 class Immutable(object):
11     """I{Immutable} objects cannot be modified once created. Any
12     modification methods will return a new object, leaving the
13     original object as it was.
14
15     The reason for this is that we want to be able to represent git
16     objects, which are immutable, and want to be able to create new
17     git objects that are just slight modifications of other git
18     objects. (Such as, for example, modifying the commit message of a
19     commit object while leaving the rest of it intact. This involves
20     creating a whole new commit object that's exactly like the old one
21     except for the commit message.)
22
23     The L{Immutable} class doesn't acytually enforce immutability --
24     that is up to the individual immutable subclasses. It just serves
25     as documentation."""
26
27 class RepositoryException(exception.StgException):
28     """Base class for all exceptions due to failed L{Repository}
29     operations."""
30
31 class DateException(exception.StgException):
32     """Exception raised when a date+time string could not be parsed."""
33     def __init__(self, string, type):
34         exception.StgException.__init__(
35             self, '"%s" is not a valid %s' % (string, type))
36
37 class DetachedHeadException(RepositoryException):
38     """Exception raised when HEAD is detached (that is, there is no
39     current branch)."""
40     def __init__(self):
41         RepositoryException.__init__(self, 'Not on any branch')
42
43 class Repr(object):
44     """Utility class that defines C{__reps__} in terms of C{__str__}."""
45     def __repr__(self):
46         return str(self)
47
48 class NoValue(object):
49     """A handy default value that is guaranteed to be distinct from any
50     real argument value."""
51     pass
52
53 def make_defaults(defaults):
54     def d(val, attr, default_fun = lambda: None):
55         if val != NoValue:
56             return val
57         elif defaults != NoValue:
58             return getattr(defaults, attr)
59         else:
60             return default_fun()
61     return d
62
63 class TimeZone(tzinfo, Repr):
64     """A simple time zone class for static offsets from UTC. (We have to
65     define our own since Python's standard library doesn't define any
66     time zone classes.)"""
67     def __init__(self, tzstring):
68         m = re.match(r'^([+-])(\d{2}):?(\d{2})$', tzstring)
69         if not m:
70             raise DateException(tzstring, 'time zone')
71         sign = int(m.group(1) + '1')
72         try:
73             self.__offset = timedelta(hours = sign*int(m.group(2)),
74                                       minutes = sign*int(m.group(3)))
75         except OverflowError:
76             raise DateException(tzstring, 'time zone')
77         self.__name = tzstring
78     def utcoffset(self, dt):
79         return self.__offset
80     def tzname(self, dt):
81         return self.__name
82     def dst(self, dt):
83         return timedelta(0)
84     def __str__(self):
85         return self.__name
86
87 class Date(Immutable, Repr):
88     """Represents a timestamp used in git commits."""
89     def __init__(self, datestring):
90         # Try git-formatted date.
91         m = re.match(r'^(\d+)\s+([+-]\d\d:?\d\d)$', datestring)
92         if m:
93             try:
94                 self.__time = datetime.fromtimestamp(int(m.group(1)),
95                                                      TimeZone(m.group(2)))
96             except ValueError:
97                 raise DateException(datestring, 'date')
98             return
99
100         # Try iso-formatted date.
101         m = re.match(r'^(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})\s+'
102                      + r'([+-]\d\d:?\d\d)$', datestring)
103         if m:
104             try:
105                 self.__time = datetime(
106                     *[int(m.group(i + 1)) for i in xrange(6)],
107                     **{'tzinfo': TimeZone(m.group(7))})
108             except ValueError:
109                 raise DateException(datestring, 'date')
110             return
111
112         raise DateException(datestring, 'date')
113     def __str__(self):
114         return self.isoformat()
115     def isoformat(self):
116         """Human-friendly ISO 8601 format."""
117         return '%s %s' % (self.__time.replace(tzinfo = None).isoformat(' '),
118                           self.__time.tzinfo)
119     @classmethod
120     def maybe(cls, datestring):
121         """Return a new object initialized with the argument if it contains a
122         value (otherwise, just return the argument)."""
123         if datestring in [None, NoValue]:
124             return datestring
125         return cls(datestring)
126
127 class Person(Immutable, Repr):
128     """Represents an author or committer in a git commit object. Contains
129     name, email and timestamp."""
130     def __init__(self, name = NoValue, email = NoValue,
131                  date = NoValue, defaults = NoValue):
132         d = make_defaults(defaults)
133         self.__name = d(name, 'name')
134         self.__email = d(email, 'email')
135         self.__date = d(date, 'date')
136         assert isinstance(self.__date, Date) or self.__date in [None, NoValue]
137     name = property(lambda self: self.__name)
138     email = property(lambda self: self.__email)
139     date = property(lambda self: self.__date)
140     def set_name(self, name):
141         return type(self)(name = name, defaults = self)
142     def set_email(self, email):
143         return type(self)(email = email, defaults = self)
144     def set_date(self, date):
145         return type(self)(date = date, defaults = self)
146     def __str__(self):
147         return '%s <%s> %s' % (self.name, self.email, self.date)
148     @classmethod
149     def parse(cls, s):
150         m = re.match(r'^([^<]*)<([^>]*)>\s+(\d+\s+[+-]\d{4})$', s)
151         assert m
152         name = m.group(1).strip()
153         email = m.group(2)
154         date = Date(m.group(3))
155         return cls(name, email, date)
156     @classmethod
157     def user(cls):
158         if not hasattr(cls, '__user'):
159             cls.__user = cls(name = config.get('user.name'),
160                              email = config.get('user.email'))
161         return cls.__user
162     @classmethod
163     def author(cls):
164         if not hasattr(cls, '__author'):
165             cls.__author = cls(
166                 name = os.environ.get('GIT_AUTHOR_NAME', NoValue),
167                 email = os.environ.get('GIT_AUTHOR_EMAIL', NoValue),
168                 date = Date.maybe(os.environ.get('GIT_AUTHOR_DATE', NoValue)),
169                 defaults = cls.user())
170         return cls.__author
171     @classmethod
172     def committer(cls):
173         if not hasattr(cls, '__committer'):
174             cls.__committer = cls(
175                 name = os.environ.get('GIT_COMMITTER_NAME', NoValue),
176                 email = os.environ.get('GIT_COMMITTER_EMAIL', NoValue),
177                 date = Date.maybe(
178                     os.environ.get('GIT_COMMITTER_DATE', NoValue)),
179                 defaults = cls.user())
180         return cls.__committer
181
182 class Tree(Immutable, Repr):
183     """Represents a git tree object."""
184     def __init__(self, sha1):
185         self.__sha1 = sha1
186     sha1 = property(lambda self: self.__sha1)
187     def __str__(self):
188         return 'Tree<%s>' % self.sha1
189
190 class CommitData(Immutable, Repr):
191     """Represents the actual data contents of a git commit object."""
192     def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
193                  committer = NoValue, message = NoValue, defaults = NoValue):
194         d = make_defaults(defaults)
195         self.__tree = d(tree, 'tree')
196         self.__parents = d(parents, 'parents')
197         self.__author = d(author, 'author', Person.author)
198         self.__committer = d(committer, 'committer', Person.committer)
199         self.__message = d(message, 'message')
200     tree = property(lambda self: self.__tree)
201     parents = property(lambda self: self.__parents)
202     @property
203     def parent(self):
204         assert len(self.__parents) == 1
205         return self.__parents[0]
206     author = property(lambda self: self.__author)
207     committer = property(lambda self: self.__committer)
208     message = property(lambda self: self.__message)
209     def set_tree(self, tree):
210         return type(self)(tree = tree, defaults = self)
211     def set_parents(self, parents):
212         return type(self)(parents = parents, defaults = self)
213     def add_parent(self, parent):
214         return type(self)(parents = list(self.parents or []) + [parent],
215                           defaults = self)
216     def set_parent(self, parent):
217         return self.set_parents([parent])
218     def set_author(self, author):
219         return type(self)(author = author, defaults = self)
220     def set_committer(self, committer):
221         return type(self)(committer = committer, defaults = self)
222     def set_message(self, message):
223         return type(self)(message = message, defaults = self)
224     def is_nochange(self):
225         return len(self.parents) == 1 and self.tree == self.parent.data.tree
226     def __str__(self):
227         if self.tree == None:
228             tree = None
229         else:
230             tree = self.tree.sha1
231         if self.parents == None:
232             parents = None
233         else:
234             parents = [p.sha1 for p in self.parents]
235         return ('CommitData<tree: %s, parents: %s, author: %s,'
236                 ' committer: %s, message: "%s">'
237                 ) % (tree, parents, self.author, self.committer, self.message)
238     @classmethod
239     def parse(cls, repository, s):
240         cd = cls(parents = [])
241         lines = list(s.splitlines(True))
242         for i in xrange(len(lines)):
243             line = lines[i].strip()
244             if not line:
245                 return cd.set_message(''.join(lines[i+1:]))
246             key, value = line.split(None, 1)
247             if key == 'tree':
248                 cd = cd.set_tree(repository.get_tree(value))
249             elif key == 'parent':
250                 cd = cd.add_parent(repository.get_commit(value))
251             elif key == 'author':
252                 cd = cd.set_author(Person.parse(value))
253             elif key == 'committer':
254                 cd = cd.set_committer(Person.parse(value))
255             else:
256                 assert False
257         assert False
258
259 class Commit(Immutable, Repr):
260     """Represents a git commit object. All the actual data contents of the
261     commit object is stored in the L{data} member, which is a
262     L{CommitData} object."""
263     def __init__(self, repository, sha1):
264         self.__sha1 = sha1
265         self.__repository = repository
266         self.__data = None
267     sha1 = property(lambda self: self.__sha1)
268     @property
269     def data(self):
270         if self.__data == None:
271             self.__data = CommitData.parse(
272                 self.__repository,
273                 self.__repository.cat_object(self.sha1))
274         return self.__data
275     def __str__(self):
276         return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
277
278 class Refs(object):
279     """Accessor for the refs stored in a git repository. Will
280     transparently cache the values of all refs."""
281     def __init__(self, repository):
282         self.__repository = repository
283         self.__refs = None
284     def __cache_refs(self):
285         """(Re-)Build the cache of all refs in the repository."""
286         self.__refs = {}
287         for line in self.__repository.run(['git', 'show-ref']).output_lines():
288             m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
289             sha1, ref = m.groups()
290             self.__refs[ref] = sha1
291     def get(self, ref):
292         """Get the Commit the given ref points to. Throws KeyError if ref
293         doesn't exist."""
294         if self.__refs == None:
295             self.__cache_refs()
296         return self.__repository.get_commit(self.__refs[ref])
297     def exists(self, ref):
298         """Check if the given ref exists."""
299         try:
300             self.get(ref)
301         except KeyError:
302             return False
303         else:
304             return True
305     def set(self, ref, commit, msg):
306         """Write the sha1 of the given Commit to the ref. The ref may or may
307         not already exist."""
308         if self.__refs == None:
309             self.__cache_refs()
310         old_sha1 = self.__refs.get(ref, '0'*40)
311         new_sha1 = commit.sha1
312         if old_sha1 != new_sha1:
313             self.__repository.run(['git', 'update-ref', '-m', msg,
314                                    ref, new_sha1, old_sha1]).no_output()
315             self.__refs[ref] = new_sha1
316     def delete(self, ref):
317         """Delete the given ref. Throws KeyError if ref doesn't exist."""
318         if self.__refs == None:
319             self.__cache_refs()
320         self.__repository.run(['git', 'update-ref',
321                                '-d', ref, self.__refs[ref]]).no_output()
322         del self.__refs[ref]
323
324 class ObjectCache(object):
325     """Cache for Python objects, for making sure that we create only one
326     Python object per git object. This reduces memory consumption and
327     makes object comparison very cheap."""
328     def __init__(self, create):
329         self.__objects = {}
330         self.__create = create
331     def __getitem__(self, name):
332         if not name in self.__objects:
333             self.__objects[name] = self.__create(name)
334         return self.__objects[name]
335     def __contains__(self, name):
336         return name in self.__objects
337     def __setitem__(self, name, val):
338         assert not name in self.__objects
339         self.__objects[name] = val
340
341 class RunWithEnv(object):
342     def run(self, args, env = {}):
343         """Run the given command with an environment given by self.env.
344
345         @type args: list of strings
346         @param args: Command and argument vector
347         @type env: dict
348         @param env: Extra environment"""
349         return run.Run(*args).env(utils.add_dict(self.env, env))
350
351 class RunWithEnvCwd(RunWithEnv):
352     def run(self, args, env = {}):
353         """Run the given command with an environment given by self.env, and
354         current working directory given by self.cwd.
355
356         @type args: list of strings
357         @param args: Command and argument vector
358         @type env: dict
359         @param env: Extra environment"""
360         return RunWithEnv.run(self, args, env).cwd(self.cwd)
361
362 class Repository(RunWithEnv):
363     """Represents a git repository."""
364     def __init__(self, directory):
365         self.__git_dir = directory
366         self.__refs = Refs(self)
367         self.__trees = ObjectCache(lambda sha1: Tree(sha1))
368         self.__commits = ObjectCache(lambda sha1: Commit(self, sha1))
369         self.__default_index = None
370         self.__default_worktree = None
371         self.__default_iw = None
372     env = property(lambda self: { 'GIT_DIR': self.__git_dir })
373     @classmethod
374     def default(cls):
375         """Return the default repository."""
376         try:
377             return cls(run.Run('git', 'rev-parse', '--git-dir'
378                                ).output_one_line())
379         except run.RunException:
380             raise RepositoryException('Cannot find git repository')
381     @property
382     def default_index(self):
383         """An L{Index} object representing the default index file for the
384         repository."""
385         if self.__default_index == None:
386             self.__default_index = Index(
387                 self, (os.environ.get('GIT_INDEX_FILE', None)
388                        or os.path.join(self.__git_dir, 'index')))
389         return self.__default_index
390     def temp_index(self):
391         """Return an L{Index} object representing a new temporary index file
392         for the repository."""
393         return Index(self, self.__git_dir)
394     @property
395     def default_worktree(self):
396         """A L{Worktree} object representing the default work tree."""
397         if self.__default_worktree == None:
398             path = os.environ.get('GIT_WORK_TREE', None)
399             if not path:
400                 o = run.Run('git', 'rev-parse', '--show-cdup').output_lines()
401                 o = o or ['.']
402                 assert len(o) == 1
403                 path = o[0]
404             self.__default_worktree = Worktree(path)
405         return self.__default_worktree
406     @property
407     def default_iw(self):
408         """An L{IndexAndWorktree} object representing the default index and
409         work tree for this repository."""
410         if self.__default_iw == None:
411             self.__default_iw = IndexAndWorktree(self.default_index,
412                                                  self.default_worktree)
413         return self.__default_iw
414     directory = property(lambda self: self.__git_dir)
415     refs = property(lambda self: self.__refs)
416     def cat_object(self, sha1):
417         return self.run(['git', 'cat-file', '-p', sha1]).raw_output()
418     def rev_parse(self, rev):
419         try:
420             return self.get_commit(self.run(
421                     ['git', 'rev-parse', '%s^{commit}' % rev]
422                     ).output_one_line())
423         except run.RunException:
424             raise RepositoryException('%s: No such revision' % rev)
425     def get_tree(self, sha1):
426         return self.__trees[sha1]
427     def get_commit(self, sha1):
428         return self.__commits[sha1]
429     def commit(self, commitdata):
430         c = ['git', 'commit-tree', commitdata.tree.sha1]
431         for p in commitdata.parents:
432             c.append('-p')
433             c.append(p.sha1)
434         env = {}
435         for p, v1 in ((commitdata.author, 'AUTHOR'),
436                        (commitdata.committer, 'COMMITTER')):
437             if p != None:
438                 for attr, v2 in (('name', 'NAME'), ('email', 'EMAIL'),
439                                  ('date', 'DATE')):
440                     if getattr(p, attr) != None:
441                         env['GIT_%s_%s' % (v1, v2)] = str(getattr(p, attr))
442         sha1 = self.run(c, env = env).raw_input(commitdata.message
443                                                 ).output_one_line()
444         return self.get_commit(sha1)
445     @property
446     def head_ref(self):
447         try:
448             return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
449                             ).output_one_line()
450         except run.RunException:
451             raise DetachedHeadException()
452     def set_head_ref(self, ref, msg):
453         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
454     def simple_merge(self, base, ours, theirs):
455         """Given three L{Tree}s, tries to do an in-index merge with a
456         temporary index. Returns the result L{Tree}, or None if the
457         merge failed (due to conflicts)."""
458         assert isinstance(base, Tree)
459         assert isinstance(ours, Tree)
460         assert isinstance(theirs, Tree)
461
462         # Take care of the really trivial cases.
463         if base == ours:
464             return theirs
465         if base == theirs:
466             return ours
467         if ours == theirs:
468             return ours
469
470         index = self.temp_index()
471         try:
472             index.merge(base, ours, theirs)
473             try:
474                 return index.write_tree()
475             except MergeException:
476                 return None
477         finally:
478             index.delete()
479     def apply(self, tree, patch_text):
480         """Given a L{Tree} and a patch, will either return the new L{Tree}
481         that results when the patch is applied, or None if the patch
482         couldn't be applied."""
483         assert isinstance(tree, Tree)
484         if not patch_text:
485             return tree
486         index = self.temp_index()
487         try:
488             index.read_tree(tree)
489             try:
490                 index.apply(patch_text)
491                 return index.write_tree()
492             except MergeException:
493                 return None
494         finally:
495             index.delete()
496     def diff_tree(self, t1, t2, diff_opts):
497         """Given two L{Tree}s C{t1} and C{t2}, return the patch that takes
498         C{t1} to C{t2}.
499
500         @type diff_opts: list of strings
501         @param diff_opts: Extra diff options
502         @rtype: String
503         @return: Patch text"""
504         assert isinstance(t1, Tree)
505         assert isinstance(t2, Tree)
506         return self.run(['git', 'diff-tree', '-p'] + list(diff_opts)
507                         + [t1.sha1, t2.sha1]).raw_output()
508
509 class MergeException(exception.StgException):
510     """Exception raised when a merge fails for some reason."""
511
512 class MergeConflictException(MergeException):
513     """Exception raised when a merge fails due to conflicts."""
514
515 class Index(RunWithEnv):
516     """Represents a git index file."""
517     def __init__(self, repository, filename):
518         self.__repository = repository
519         if os.path.isdir(filename):
520             # Create a temp index in the given directory.
521             self.__filename = os.path.join(
522                 filename, 'index.temp-%d-%x' % (os.getpid(), id(self)))
523             self.delete()
524         else:
525             self.__filename = filename
526     env = property(lambda self: utils.add_dict(
527             self.__repository.env, { 'GIT_INDEX_FILE': self.__filename }))
528     def read_tree(self, tree):
529         self.run(['git', 'read-tree', tree.sha1]).no_output()
530     def write_tree(self):
531         try:
532             return self.__repository.get_tree(
533                 self.run(['git', 'write-tree']).discard_stderr(
534                     ).output_one_line())
535         except run.RunException:
536             raise MergeException('Conflicting merge')
537     def is_clean(self):
538         try:
539             self.run(['git', 'update-index', '--refresh']).discard_output()
540         except run.RunException:
541             return False
542         else:
543             return True
544     def merge(self, base, ours, theirs):
545         """In-index merge, no worktree involved."""
546         self.run(['git', 'read-tree', '-m', '-i', '--aggressive',
547                   base.sha1, ours.sha1, theirs.sha1]).no_output()
548     def apply(self, patch_text):
549         """In-index patch application, no worktree involved."""
550         try:
551             self.run(['git', 'apply', '--cached']
552                      ).raw_input(patch_text).no_output()
553         except run.RunException:
554             raise MergeException('Patch does not apply cleanly')
555     def delete(self):
556         if os.path.isfile(self.__filename):
557             os.remove(self.__filename)
558     def conflicts(self):
559         """The set of conflicting paths."""
560         paths = set()
561         for line in self.run(['git', 'ls-files', '-z', '--unmerged']
562                              ).raw_output().split('\0')[:-1]:
563             stat, path = line.split('\t', 1)
564             paths.add(path)
565         return paths
566
567 class Worktree(object):
568     """Represents a git worktree (that is, a checked-out file tree)."""
569     def __init__(self, directory):
570         self.__directory = directory
571     env = property(lambda self: { 'GIT_WORK_TREE': '.' })
572     directory = property(lambda self: self.__directory)
573
574 class CheckoutException(exception.StgException):
575     """Exception raised when a checkout fails."""
576
577 class IndexAndWorktree(RunWithEnvCwd):
578     """Represents a git index and a worktree. Anything that an index or
579     worktree can do on their own are handled by the L{Index} and
580     L{Worktree} classes; this class concerns itself with the
581     operations that require both."""
582     def __init__(self, index, worktree):
583         self.__index = index
584         self.__worktree = worktree
585     index = property(lambda self: self.__index)
586     env = property(lambda self: utils.add_dict(self.__index.env,
587                                                self.__worktree.env))
588     cwd = property(lambda self: self.__worktree.directory)
589     def checkout(self, old_tree, new_tree):
590         # TODO: Optionally do a 3-way instead of doing nothing when we
591         # have a problem. Or maybe we should stash changes in a patch?
592         assert isinstance(old_tree, Tree)
593         assert isinstance(new_tree, Tree)
594         try:
595             self.run(['git', 'read-tree', '-u', '-m',
596                       '--exclude-per-directory=.gitignore',
597                       old_tree.sha1, new_tree.sha1]
598                      ).discard_output()
599         except run.RunException:
600             raise CheckoutException('Index/workdir dirty')
601     def merge(self, base, ours, theirs):
602         assert isinstance(base, Tree)
603         assert isinstance(ours, Tree)
604         assert isinstance(theirs, Tree)
605         try:
606             r = self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
607                           theirs.sha1],
608                          env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
609                                  'GITHEAD_%s' % ours.sha1: 'current',
610                                  'GITHEAD_%s' % theirs.sha1: 'patched'})
611             r.discard_output()
612         except run.RunException, e:
613             if r.exitcode == 1:
614                 raise MergeConflictException()
615             else:
616                 raise MergeException('Index/worktree dirty')
617     def changed_files(self):
618         return self.run(['git', 'diff-files', '--name-only']).output_lines()
619     def update_index(self, files):
620         self.run(['git', 'update-index', '--remove', '-z', '--stdin']
621                  ).input_nulterm(files).discard_output()