chiark / gitweb /
b3181cbe02d595e40de34849d2f38a89f00b702d
[stgit] / stgit / lib / git.py
1 import os, os.path, re
2 from datetime import datetime, timedelta, tzinfo
3
4 from stgit import exception, run, utils
5 from stgit.config import config
6
class RepositoryException(exception.StgException):
    """Raised for repository-level failures, e.g. when no git
    repository can be found or a revision does not exist."""
9
class DateException(exception.StgException):
    """Raised when a string cannot be parsed as the expected kind of
    value (a date or a time zone)."""
    def __init__(self, string, type):
        # 'type' is the human-readable name of what was expected, e.g.
        # 'date' or 'time zone'; it only ends up in the message.
        message = '"%s" is not a valid %s' % (string, type)
        exception.StgException.__init__(self, message)
14
class DetachedHeadException(RepositoryException):
    """Raised when the current branch is requested but HEAD is not a
    symbolic ref."""
    def __init__(self):
        message = 'Not on any branch'
        RepositoryException.__init__(self, message)
18
class Repr(object):
    """Mixin that makes repr() of an object identical to its str().

    Subclasses are expected to define __str__."""
    def __repr__(self):
        text = str(self)
        return text
22
class NoValue(object):
    """Sentinel meaning "no value was supplied".

    The class object itself (not an instance) is used as the marker, so
    that None stays available as an ordinary value."""
25
def make_defaults(defaults):
    """Return a function that fills in values that were left as NoValue.

    The returned function takes (val, attr, default_fun) and yields, in
    order of preference: val itself unless it is NoValue; the attribute
    attr of defaults if defaults is not NoValue; the result of calling
    default_fun (which returns None unless overridden)."""
    def resolve(val, attr, default_fun = lambda: None):
        # An explicitly supplied value always wins.
        if val != NoValue:
            return val
        # Otherwise inherit from the defaults object, if there is one.
        if defaults != NoValue:
            return getattr(defaults, attr)
        return default_fun()
    return resolve
35
class TimeZone(tzinfo, Repr):
    """Fixed-offset time zone built from a string such as "+0100" or
    "-05:30" (sign, two hour digits, optional colon, two minute digits).

    Raises DateException if the string does not have that form."""
    def __init__(self, tzstring):
        parsed = re.match(r'^([+-])(\d{2}):?(\d{2})$', tzstring)
        if parsed is None:
            raise DateException(tzstring, 'time zone')
        sign_char, hours, minutes = parsed.groups()
        # Turn '+'/'-' into +1/-1 so it can scale both components.
        sign = int(sign_char + '1')
        try:
            self.__offset = timedelta(hours = sign*int(hours),
                                      minutes = sign*int(minutes))
        except OverflowError:
            raise DateException(tzstring, 'time zone')
        # The name of the zone is simply the string it was built from.
        self.__name = tzstring
    def utcoffset(self, dt):
        """The fixed offset; dt is ignored."""
        return self.__offset
    def tzname(self, dt):
        """The original time zone string; dt is ignored."""
        return self.__name
    def dst(self, dt):
        """Always zero: no daylight saving time is modelled."""
        return timedelta(0)
    def __str__(self):
        return self.__name
56
class Date(Repr):
    """Immutable. A point in time together with a fixed UTC offset.

    The constructor accepts either git's raw format
    ("<seconds since epoch> <+hhmm>") or an ISO-like format
    ("YYYY-MM-DD HH:MM:SS <+hhmm>"); the colon in the offset is
    optional. Anything else raises DateException."""
    def __init__(self, datestring):
        # Try git-formatted date.
        m = re.match(r'^(\d+)\s+([+-]\d\d:?\d\d)$', datestring)
        if m:
            try:
                self.__time = datetime.fromtimestamp(int(m.group(1)),
                                                     TimeZone(m.group(2)))
            except (ValueError, OverflowError, OSError):
                # Depending on platform and Python version, an
                # out-of-range timestamp is reported with any of these
                # exceptions; all of them mean "not a usable date".
                raise DateException(datestring, 'date')
            return

        # Try iso-formatted date.
        m = re.match(r'^(\d{4})-(\d{2})-(\d{2})\s+(\d{2}):(\d{2}):(\d{2})\s+'
                     + r'([+-]\d\d:?\d\d)$', datestring)
        if m:
            # year, month, day, hour, minute, second
            fields = [int(g) for g in m.groups()[:6]]
            try:
                # Positional arguments 7 and 8 of datetime() are the
                # microsecond and the tzinfo.
                self.__time = datetime(
                    *(fields + [0, TimeZone(m.group(7))]))
            except ValueError:
                raise DateException(datestring, 'date')
            return

        raise DateException(datestring, 'date')
    def __str__(self):
        return self.isoformat()
    def isoformat(self):
        """Human-friendly ISO 8601 format."""
        # Format the naive local time and append the zone string
        # ourselves, instead of datetime's own offset rendering.
        return '%s %s' % (self.__time.replace(tzinfo = None).isoformat(' '),
                          self.__time.tzinfo)
    @classmethod
    def maybe(cls, datestring):
        """Like the constructor, but pass None and NoValue through
        unchanged instead of trying to parse them."""
        if datestring in [None, NoValue]:
            return datestring
        return cls(datestring)
94
class Person(Repr):
    """Immutable. A name, an email address and a date, as recorded for
    the author and the committer of a commit.

    Any field left as NoValue is taken from defaults (another Person)
    when one is given, and becomes None otherwise. The set_* methods
    return a modified copy rather than changing the object."""
    def __init__(self, name = NoValue, email = NoValue,
                 date = NoValue, defaults = NoValue):
        d = make_defaults(defaults)
        self.__name = d(name, 'name')
        self.__email = d(email, 'email')
        self.__date = d(date, 'date')
        assert isinstance(self.__date, Date) or self.__date in [None, NoValue]
    name = property(lambda self: self.__name)
    email = property(lambda self: self.__email)
    date = property(lambda self: self.__date)
    def set_name(self, name):
        """Return a copy with the name replaced."""
        return type(self)(name = name, defaults = self)
    def set_email(self, email):
        """Return a copy with the email replaced."""
        return type(self)(email = email, defaults = self)
    def set_date(self, date):
        """Return a copy with the date replaced."""
        return type(self)(date = date, defaults = self)
    def __str__(self):
        return '%s <%s> %s' % (self.name, self.email, self.date)
    @classmethod
    def parse(cls, s):
        """Build a Person from a string of the form
        "Name <email> <seconds since epoch> <+hhmm>"."""
        m = re.match(r'^([^<]*)<([^>]*)>\s+(\d+\s+[+-]\d{4})$', s)
        assert m
        name = m.group(1).strip()
        email = m.group(2)
        date = Date(m.group(3))
        return cls(name, email, date)
    # The three constructors below cache their result on the class.
    # cls.__user is stored under the mangled name _Person__user, but
    # hasattr() takes a plain string that Python does not mangle, so the
    # mangled name has to be spelled out for the cache check to ever
    # succeed.
    @classmethod
    def user(cls):
        """The person described by the user.name and user.email
        configuration values (cached after the first call)."""
        if not hasattr(cls, '_Person__user'):
            cls.__user = cls(name = config.get('user.name'),
                             email = config.get('user.email'))
        return cls.__user
    @classmethod
    def author(cls):
        """The person described by the GIT_AUTHOR_* environment
        variables, falling back to user() for unset fields (cached after
        the first call)."""
        if not hasattr(cls, '_Person__author'):
            cls.__author = cls(
                name = os.environ.get('GIT_AUTHOR_NAME', NoValue),
                email = os.environ.get('GIT_AUTHOR_EMAIL', NoValue),
                date = Date.maybe(os.environ.get('GIT_AUTHOR_DATE', NoValue)),
                defaults = cls.user())
        return cls.__author
    @classmethod
    def committer(cls):
        """The person described by the GIT_COMMITTER_* environment
        variables, falling back to user() for unset fields (cached after
        the first call)."""
        if not hasattr(cls, '_Person__committer'):
            cls.__committer = cls(
                name = os.environ.get('GIT_COMMITTER_NAME', NoValue),
                email = os.environ.get('GIT_COMMITTER_EMAIL', NoValue),
                date = Date.maybe(
                    os.environ.get('GIT_COMMITTER_DATE', NoValue)),
                defaults = cls.user())
        return cls.__committer
148
class Tree(Repr):
    """Immutable. A git tree object, identified by its sha1."""
    def __init__(self, sha1):
        self.__sha1 = sha1
    @property
    def sha1(self):
        return self.__sha1
    def __str__(self):
        return 'Tree<%s>' % self.sha1
156
class CommitData(Repr):
    """Immutable. The contents of a commit: tree, parents, author,
    committer and message.

    Fields left as NoValue are taken from defaults (another CommitData)
    when one is given. Without defaults, author and committer fall back
    to Person.author() and Person.committer(), and the remaining fields
    become None. The set_* and add_parent methods return a modified
    copy rather than changing the object."""
    def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
                 committer = NoValue, message = NoValue, defaults = NoValue):
        d = make_defaults(defaults)
        self.__tree = d(tree, 'tree')
        self.__parents = d(parents, 'parents')
        self.__author = d(author, 'author', Person.author)
        self.__committer = d(committer, 'committer', Person.committer)
        self.__message = d(message, 'message')
    tree = property(lambda self: self.__tree)
    parents = property(lambda self: self.__parents)
    @property
    def parent(self):
        """The only parent; it is an error to ask for it when the
        number of parents is not exactly one."""
        assert len(self.__parents) == 1
        return self.__parents[0]
    author = property(lambda self: self.__author)
    committer = property(lambda self: self.__committer)
    message = property(lambda self: self.__message)
    def set_tree(self, tree):
        return type(self)(tree = tree, defaults = self)
    def set_parents(self, parents):
        return type(self)(parents = parents, defaults = self)
    def add_parent(self, parent):
        """Return a copy with parent appended to the list of parents."""
        return type(self)(parents = list(self.parents or []) + [parent],
                          defaults = self)
    def set_parent(self, parent):
        """Return a copy whose only parent is the given one."""
        return self.set_parents([parent])
    def set_author(self, author):
        return type(self)(author = author, defaults = self)
    def set_committer(self, committer):
        return type(self)(committer = committer, defaults = self)
    def set_message(self, message):
        return type(self)(message = message, defaults = self)
    def is_nochange(self):
        """True if there is exactly one parent and its tree is the same
        as this commit's tree."""
        return len(self.parents) == 1 and self.tree == self.parent.data.tree
    def __str__(self):
        if self.tree is None:
            tree = None
        else:
            tree = self.tree.sha1
        if self.parents is None:
            parents = None
        else:
            parents = [p.sha1 for p in self.parents]
        return ('CommitData<tree: %s, parents: %s, author: %s,'
                ' committer: %s, message: "%s">'
                ) % (tree, parents, self.author, self.committer, self.message)
    @classmethod
    def parse(cls, repository, s):
        """Build a CommitData from the text of a commit object.

        s consists of header lines, one empty line, and then the commit
        message. The tree, parent, author and committer headers are
        interpreted; tree and parent objects are obtained from
        repository. Any other header, and any continuation line of a
        multi-line header value, is skipped."""
        cd = cls(parents = [])
        lines = s.splitlines(True)
        for i, raw in enumerate(lines):
            if raw.startswith(' '):
                # Continuation of the previous header's value (as used
                # by e.g. gpgsig); must be tested before stripping, since
                # a continuation line may consist of just the space.
                continue
            line = raw.strip()
            if not line:
                # End of headers; everything after it is the message.
                return cd.set_message(''.join(lines[i+1:]))
            key, value = line.split(None, 1)
            if key == 'tree':
                cd = cd.set_tree(repository.get_tree(value))
            elif key == 'parent':
                cd = cd.add_parent(repository.get_commit(value))
            elif key == 'author':
                cd = cd.set_author(Person.parse(value))
            elif key == 'committer':
                cd = cd.set_committer(Person.parse(value))
            # Headers that are not modelled here (e.g. encoding, gpgsig,
            # mergetag) are ignored instead of aborting the parse.
        assert False, 'no empty line separating headers from message'
225
class Commit(Repr):
    """Immutable. A git commit, identified by its sha1. The commit's
    contents are read from the repository the first time the data
    property is used."""
    def __init__(self, repository, sha1):
        self.__sha1 = sha1
        self.__repository = repository
        self.__data = None
    @property
    def sha1(self):
        return self.__sha1
    @property
    def data(self):
        """The CommitData of this commit, parsed on first access."""
        if self.__data is None:
            repo = self.__repository
            raw = repo.cat_object(self.sha1)
            self.__data = CommitData.parse(repo, raw)
        return self.__data
    def __str__(self):
        # Deliberately uses the cached field, not the property, so that
        # printing a commit never triggers loading its data.
        return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
242
class Refs(object):
    """Access to the refs of a repository, backed by an in-memory map
    from ref name to sha1 that is filled from 'git show-ref' on first
    use and kept up to date by set() and delete()."""
    def __init__(self, repository):
        self.__repository = repository
        self.__refs = None
    def __cache_refs(self):
        """(Re)build the ref name -> sha1 map from 'git show-ref'."""
        cache = {}
        self.__refs = cache
        listing = self.__repository.run(['git', 'show-ref'])
        for line in listing.output_lines():
            m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
            sha1, ref = m.groups()
            cache[ref] = sha1
    def get(self, ref):
        """Throws KeyError if ref doesn't exist."""
        if self.__refs is None:
            self.__cache_refs()
        sha1 = self.__refs[ref]
        return self.__repository.get_commit(sha1)
    def exists(self, ref):
        """Whether the given ref is known."""
        try:
            self.get(ref)
        except KeyError:
            return False
        return True
    def set(self, ref, commit, msg):
        """Point ref at commit, using msg as the message given to
        'git update-ref'. Does nothing if ref already points there."""
        if self.__refs is None:
            self.__cache_refs()
        # A ref that does not exist yet is represented by the all-zero
        # sha1 as the expected old value.
        old_sha1 = self.__refs.get(ref, '0'*40)
        new_sha1 = commit.sha1
        if old_sha1 == new_sha1:
            return
        command = ['git', 'update-ref', '-m', msg, ref, new_sha1, old_sha1]
        self.__repository.run(command).no_output()
        self.__refs[ref] = new_sha1
    def delete(self, ref):
        """Delete ref. Throws KeyError if ref doesn't exist."""
        if self.__refs is None:
            self.__cache_refs()
        command = ['git', 'update-ref', '-d', ref, self.__refs[ref]]
        self.__repository.run(command).no_output()
        del self.__refs[ref]
280
class ObjectCache(object):
    """Cache for Python objects, for making sure that we create only one
    Python object per git object.

    Looking up a name that is not cached yet calls the create function
    given to the constructor and remembers its result."""
    def __init__(self, create):
        self.__objects = {}
        self.__create = create
    def __getitem__(self, name):
        store = self.__objects
        if name in store:
            return store[name]
        made = self.__create(name)
        store[name] = made
        return made
    def __contains__(self, name):
        return name in self.__objects
    def __setitem__(self, name, val):
        # Entries may be added explicitly, but never replaced.
        assert name not in self.__objects
        self.__objects[name] = val
296
class RunWithEnv(object):
    """Mixin for objects that run commands in their own environment.

    Classes using it must provide an env attribute."""
    def run(self, args, env = None):
        """Return a run.Run object for the command args, with its
        environment set to the combination (via utils.add_dict) of
        self.env and the optional extra env."""
        # Use None rather than a shared {} as the default, so that no
        # dict object is shared between calls.
        if env is None:
            env = {}
        return run.Run(*args).env(utils.add_dict(self.env, env))
300
class RunWithEnvCwd(RunWithEnv):
    """Like RunWithEnv, but also runs the command in a given directory.

    Classes using it must provide a cwd attribute in addition to env."""
    def run(self, args, env = None):
        """Return the object produced by RunWithEnv.run for args and
        env, with its working directory set to self.cwd."""
        # Use None rather than a shared {} as the default; always hand
        # an actual dict on to the base class.
        if env is None:
            env = {}
        return RunWithEnv.run(self, args, env).cwd(self.cwd)
304
class Repository(RunWithEnv):
    """A git repository, identified by its git directory.

    Hands out Tree and Commit objects (one Python object per sha1),
    gives access to the refs, and runs git commands with GIT_DIR set to
    the repository's directory."""
    def __init__(self, directory):
        self.__git_dir = directory
        self.__refs = Refs(self)
        self.__trees = ObjectCache(lambda sha1: Tree(sha1))
        self.__commits = ObjectCache(lambda sha1: Commit(self, sha1))
        # The default index, worktree and their combination are created
        # on first use.
        self.__default_index = None
        self.__default_worktree = None
        self.__default_iw = None
    @property
    def env(self):
        return { 'GIT_DIR': self.__git_dir }
    @classmethod
    def default(cls):
        """Return the default repository."""
        try:
            git_dir = run.Run('git', 'rev-parse', '--git-dir'
                              ).output_one_line()
            return cls(git_dir)
        except run.RunException:
            raise RepositoryException('Cannot find git repository')
    @property
    def default_index(self):
        """The index named by GIT_INDEX_FILE if that is set, otherwise
        the file 'index' in the git directory."""
        if self.__default_index is None:
            index_file = (os.environ.get('GIT_INDEX_FILE', None)
                          or os.path.join(self.__git_dir, 'index'))
            self.__default_index = Index(self, index_file)
        return self.__default_index
    def temp_index(self):
        """Return a new Index backed by a temporary file in the git
        directory."""
        return Index(self, self.__git_dir)
    @property
    def default_worktree(self):
        """The worktree named by GIT_WORK_TREE if that is set, otherwise
        the one reported by 'git rev-parse --show-cdup'."""
        if self.__default_worktree is None:
            path = os.environ.get('GIT_WORK_TREE', None)
            if not path:
                cdup = run.Run('git', 'rev-parse', '--show-cdup'
                               ).output_lines()
                # No output at all is taken to mean the current directory.
                cdup = cdup or ['.']
                assert len(cdup) == 1
                path = cdup[0]
            self.__default_worktree = Worktree(path)
        return self.__default_worktree
    @property
    def default_iw(self):
        """The default index combined with the default worktree."""
        if self.__default_iw is None:
            self.__default_iw = IndexAndWorktree(self.default_index,
                                                 self.default_worktree)
        return self.__default_iw
    @property
    def directory(self):
        return self.__git_dir
    @property
    def refs(self):
        return self.__refs
    def cat_object(self, sha1):
        """Return the raw output of 'git cat-file -p' for sha1."""
        command = ['git', 'cat-file', '-p', sha1]
        return self.run(command).raw_output()
    def rev_parse(self, rev):
        """Return the Commit that rev refers to; raises
        RepositoryException if git cannot resolve it to a commit."""
        try:
            sha1 = self.run(['git', 'rev-parse', '%s^{commit}' % rev]
                            ).output_one_line()
            return self.get_commit(sha1)
        except run.RunException:
            raise RepositoryException('%s: No such revision' % rev)
    def get_tree(self, sha1):
        return self.__trees[sha1]
    def get_commit(self, sha1):
        return self.__commits[sha1]
    def commit(self, commitdata):
        """Create a commit from the given CommitData with
        'git commit-tree' and return it as a Commit."""
        command = ['git', 'commit-tree', commitdata.tree.sha1]
        for parent in commitdata.parents:
            command.extend(['-p', parent.sha1])
        # Author and committer details are passed through the
        # GIT_AUTHOR_* and GIT_COMMITTER_* environment variables; fields
        # that are None are simply left out.
        env = {}
        people = ((commitdata.author, 'AUTHOR'),
                  (commitdata.committer, 'COMMITTER'))
        fields = (('name', 'NAME'), ('email', 'EMAIL'), ('date', 'DATE'))
        for person, role in people:
            if person is None:
                continue
            for attr, suffix in fields:
                value = getattr(person, attr)
                if value is not None:
                    env['GIT_%s_%s' % (role, suffix)] = str(value)
        sha1 = self.run(command, env = env).raw_input(commitdata.message
                                                      ).output_one_line()
        return self.get_commit(sha1)
    @property
    def head(self):
        """The output of 'git symbolic-ref -q HEAD'; raises
        DetachedHeadException if that command fails."""
        try:
            return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
                            ).output_one_line()
        except run.RunException:
            raise DetachedHeadException()
    def set_head(self, ref, msg):
        """Make HEAD a symbolic ref to ref, passing msg to
        'git symbolic-ref -m'."""
        command = ['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]
        self.run(command).no_output()
    def simple_merge(self, base, ours, theirs):
        """Given three trees, tries to do an in-index merge with a
        temporary index. Returns the result tree, or None if the merge
        failed (due to conflicts)."""
        for t in (base, ours, theirs):
            assert isinstance(t, Tree)

        # Take care of the really trivial cases.
        if base == ours:
            return theirs
        if base == theirs or ours == theirs:
            return ours

        index = self.temp_index()
        try:
            index.merge(base, ours, theirs)
            try:
                merged = index.write_tree()
            except MergeException:
                merged = None
            return merged
        finally:
            index.delete()
    def apply(self, tree, patch_text):
        """Given a tree and a patch, will either return the new tree that
        results when the patch is applied, or None if the patch
        couldn't be applied."""
        assert isinstance(tree, Tree)
        if not patch_text:
            # An empty patch changes nothing.
            return tree
        index = self.temp_index()
        try:
            index.read_tree(tree)
            try:
                index.apply(patch_text)
                patched = index.write_tree()
            except MergeException:
                patched = None
            return patched
        finally:
            index.delete()
    def diff_tree(self, t1, t2, diff_opts):
        """Return the raw output of 'git diff-tree -p' between the trees
        t1 and t2, with diff_opts as additional options."""
        assert isinstance(t1, Tree)
        assert isinstance(t2, Tree)
        command = ['git', 'diff-tree', '-p']
        command.extend(diff_opts)
        command.extend([t1.sha1, t2.sha1])
        return self.run(command).raw_output()
436
class MergeException(exception.StgException):
    """Raised when a merge, a patch application or writing out the
    resulting tree fails."""
439
class MergeConflictException(MergeException):
    """Raised when a merge in the worktree stops because of
    conflicts."""
442
class Index(RunWithEnv):
    """A git index file belonging to a repository.

    Commands run through this object get GIT_INDEX_FILE set to the index
    file, in addition to the repository's own environment. If the given
    filename is a directory, a new temporary index file name inside that
    directory is used instead."""
    def __init__(self, repository, filename):
        self.__repository = repository
        if not os.path.isdir(filename):
            self.__filename = filename
            return
        # Create a temp index in the given directory.
        temp_name = 'index.temp-%d-%x' % (os.getpid(), id(self))
        self.__filename = os.path.join(filename, temp_name)
        # Get rid of any stale file that happens to have the same name.
        self.delete()
    @property
    def env(self):
        return utils.add_dict(
            self.__repository.env, { 'GIT_INDEX_FILE': self.__filename })
    def read_tree(self, tree):
        """Run 'git read-tree' with the given tree."""
        self.run(['git', 'read-tree', tree.sha1]).no_output()
    def write_tree(self):
        """Write the index out as a tree and return it as a Tree; raises
        MergeException if 'git write-tree' fails."""
        try:
            sha1 = self.run(['git', 'write-tree']).discard_stderr(
                ).output_one_line()
            return self.__repository.get_tree(sha1)
        except run.RunException:
            raise MergeException('Conflicting merge')
    def is_clean(self):
        """Whether 'git update-index --refresh' succeeds."""
        try:
            self.run(['git', 'update-index', '--refresh']).discard_output()
        except run.RunException:
            return False
        return True
    def merge(self, base, ours, theirs):
        """In-index merge, no worktree involved."""
        command = ['git', 'read-tree', '-m', '-i', '--aggressive',
                   base.sha1, ours.sha1, theirs.sha1]
        self.run(command).no_output()
    def apply(self, patch_text):
        """In-index patch application, no worktree involved."""
        try:
            self.run(['git', 'apply', '--cached']
                     ).raw_input(patch_text).no_output()
        except run.RunException:
            raise MergeException('Patch does not apply cleanly')
    def delete(self):
        """Remove the index file, if it exists."""
        if os.path.isfile(self.__filename):
            os.remove(self.__filename)
    def conflicts(self):
        """The set of conflicting paths."""
        listing = self.run(['git', 'ls-files', '-z', '--unmerged']
                           ).raw_output()
        paths = set()
        # The output is NUL-terminated, so the last piece after
        # splitting is dropped.
        for entry in listing.split('\0')[:-1]:
            # Each entry is "<stat info>\t<path>"; only the path is used.
            _, path = entry.split('\t', 1)
            paths.add(path)
        return paths
493
class Worktree(object):
    """A working tree, identified by its directory."""
    def __init__(self, directory):
        self.__directory = directory
    @property
    def env(self):
        # The work tree is always given as '.', independent of the
        # directory stored here.
        return { 'GIT_WORK_TREE': '.' }
    @property
    def directory(self):
        return self.__directory
499
class CheckoutException(exception.StgException):
    """Raised when switching the index and worktree from one tree to
    another fails."""
502
class IndexAndWorktree(RunWithEnvCwd):
    """An index paired with a worktree.

    Commands run through this object use the combined environment of
    the index and the worktree, and are run in the worktree's
    directory."""
    def __init__(self, index, worktree):
        self.__index = index
        self.__worktree = worktree
    index = property(lambda self: self.__index)
    env = property(lambda self: utils.add_dict(self.__index.env,
                                               self.__worktree.env))
    cwd = property(lambda self: self.__worktree.directory)
    def checkout(self, old_tree, new_tree):
        """Update index and worktree from old_tree to new_tree with
        'git read-tree -u -m'; raises CheckoutException if that
        fails."""
        # TODO: Optionally do a 3-way instead of doing nothing when we
        # have a problem. Or maybe we should stash changes in a patch?
        assert isinstance(old_tree, Tree)
        assert isinstance(new_tree, Tree)
        try:
            self.run(['git', 'read-tree', '-u', '-m',
                      '--exclude-per-directory=.gitignore',
                      old_tree.sha1, new_tree.sha1]
                     ).discard_output()
        except run.RunException:
            raise CheckoutException('Index/workdir dirty')
    def merge(self, base, ours, theirs):
        """Three-way merge of the given trees with
        'git merge-recursive'.

        Raises MergeConflictException if the command exits with status
        1, and MergeException if it fails in any other way."""
        assert isinstance(base, Tree)
        assert isinstance(ours, Tree)
        assert isinstance(theirs, Tree)
        # Build the command object before entering the try block, so
        # that it is guaranteed to be bound when the handler inspects
        # its exit code. The GITHEAD_<sha1> variables give the three
        # trees readable names.
        r = self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
                      theirs.sha1],
                     env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
                             'GITHEAD_%s' % ours.sha1: 'current',
                             'GITHEAD_%s' % theirs.sha1: 'patched'})
        try:
            r.discard_output()
        except run.RunException:
            if r.exitcode == 1:
                raise MergeConflictException()
            else:
                raise MergeException('Index/worktree dirty')
    def changed_files(self):
        """Return the output lines of 'git diff-files --name-only'."""
        return self.run(['git', 'diff-files', '--name-only']).output_lines()
    def update_index(self, files):
        """Run 'git update-index --remove -z --stdin', feeding it the
        given files through input_nulterm."""
        self.run(['git', 'update-index', '--remove', '-z', '--stdin']
                 ).input_nulterm(files).discard_output()