chiark / gitweb /
2ca44952c3d820135ac4501bf0c0ab2caced321f
[stgit] / stgit / lib / git.py
1 import os, os.path, re
2 from stgit import exception, run, utils
3 from stgit.config import config
4
5 class RepositoryException(exception.StgException):
6     pass
7
8 class DetachedHeadException(RepositoryException):
9     def __init__(self):
10         RepositoryException.__init__(self, 'Not on any branch')
11
12 class Repr(object):
13     def __repr__(self):
14         return str(self)
15
16 class NoValue(object):
17     pass
18
19 def make_defaults(defaults):
20     def d(val, attr):
21         if val != NoValue:
22             return val
23         elif defaults != NoValue:
24             return getattr(defaults, attr)
25         else:
26             return None
27     return d
28
29 class Person(Repr):
30     """Immutable."""
31     def __init__(self, name = NoValue, email = NoValue,
32                  date = NoValue, defaults = NoValue):
33         d = make_defaults(defaults)
34         self.__name = d(name, 'name')
35         self.__email = d(email, 'email')
36         self.__date = d(date, 'date')
37     name = property(lambda self: self.__name)
38     email = property(lambda self: self.__email)
39     date = property(lambda self: self.__date)
40     def set_name(self, name):
41         return type(self)(name = name, defaults = self)
42     def set_email(self, email):
43         return type(self)(email = email, defaults = self)
44     def set_date(self, date):
45         return type(self)(date = date, defaults = self)
46     def __str__(self):
47         return '%s <%s> %s' % (self.name, self.email, self.date)
48     @classmethod
49     def parse(cls, s):
50         m = re.match(r'^([^<]*)<([^>]*)>\s+(\d+\s+[+-]\d{4})$', s)
51         assert m
52         name = m.group(1).strip()
53         email = m.group(2)
54         date = m.group(3)
55         return cls(name, email, date)
56     @classmethod
57     def user(cls):
58         if not hasattr(cls, '__user'):
59             cls.__user = cls(name = config.get('user.name'),
60                              email = config.get('user.email'))
61         return cls.__user
62     @classmethod
63     def author(cls):
64         if not hasattr(cls, '__author'):
65             cls.__author = cls(
66                 name = os.environ.get('GIT_AUTHOR_NAME', NoValue),
67                 email = os.environ.get('GIT_AUTHOR_EMAIL', NoValue),
68                 date = os.environ.get('GIT_AUTHOR_DATE', NoValue),
69                 defaults = cls.user())
70         return cls.__author
71     @classmethod
72     def committer(cls):
73         if not hasattr(cls, '__committer'):
74             cls.__committer = cls(
75                 name = os.environ.get('GIT_COMMITTER_NAME', NoValue),
76                 email = os.environ.get('GIT_COMMITTER_EMAIL', NoValue),
77                 date = os.environ.get('GIT_COMMITTER_DATE', NoValue),
78                 defaults = cls.user())
79         return cls.__committer
80
81 class Tree(Repr):
82     """Immutable."""
83     def __init__(self, sha1):
84         self.__sha1 = sha1
85     sha1 = property(lambda self: self.__sha1)
86     def __str__(self):
87         return 'Tree<%s>' % self.sha1
88
89 class Commitdata(Repr):
90     """Immutable."""
91     def __init__(self, tree = NoValue, parents = NoValue, author = NoValue,
92                  committer = NoValue, message = NoValue, defaults = NoValue):
93         d = make_defaults(defaults)
94         self.__tree = d(tree, 'tree')
95         self.__parents = d(parents, 'parents')
96         self.__author = d(author, 'author')
97         self.__committer = d(committer, 'committer')
98         self.__message = d(message, 'message')
99     tree = property(lambda self: self.__tree)
100     parents = property(lambda self: self.__parents)
101     @property
102     def parent(self):
103         assert len(self.__parents) == 1
104         return self.__parents[0]
105     author = property(lambda self: self.__author)
106     committer = property(lambda self: self.__committer)
107     message = property(lambda self: self.__message)
108     def set_tree(self, tree):
109         return type(self)(tree = tree, defaults = self)
110     def set_parents(self, parents):
111         return type(self)(parents = parents, defaults = self)
112     def add_parent(self, parent):
113         return type(self)(parents = list(self.parents or []) + [parent],
114                           defaults = self)
115     def set_parent(self, parent):
116         return self.set_parents([parent])
117     def set_author(self, author):
118         return type(self)(author = author, defaults = self)
119     def set_committer(self, committer):
120         return type(self)(committer = committer, defaults = self)
121     def set_message(self, message):
122         return type(self)(message = message, defaults = self)
123     def is_nochange(self):
124         return len(self.parents) == 1 and self.tree == self.parent.data.tree
125     def __str__(self):
126         if self.tree == None:
127             tree = None
128         else:
129             tree = self.tree.sha1
130         if self.parents == None:
131             parents = None
132         else:
133             parents = [p.sha1 for p in self.parents]
134         return ('Commitdata<tree: %s, parents: %s, author: %s,'
135                 ' committer: %s, message: "%s">'
136                 ) % (tree, parents, self.author, self.committer, self.message)
137     @classmethod
138     def parse(cls, repository, s):
139         cd = cls()
140         lines = list(s.splitlines(True))
141         for i in xrange(len(lines)):
142             line = lines[i].strip()
143             if not line:
144                 return cd.set_message(''.join(lines[i+1:]))
145             key, value = line.split(None, 1)
146             if key == 'tree':
147                 cd = cd.set_tree(repository.get_tree(value))
148             elif key == 'parent':
149                 cd = cd.add_parent(repository.get_commit(value))
150             elif key == 'author':
151                 cd = cd.set_author(Person.parse(value))
152             elif key == 'committer':
153                 cd = cd.set_committer(Person.parse(value))
154             else:
155                 assert False
156         assert False
157
158 class Commit(Repr):
159     """Immutable."""
160     def __init__(self, repository, sha1):
161         self.__sha1 = sha1
162         self.__repository = repository
163         self.__data = None
164     sha1 = property(lambda self: self.__sha1)
165     @property
166     def data(self):
167         if self.__data == None:
168             self.__data = Commitdata.parse(
169                 self.__repository,
170                 self.__repository.cat_object(self.sha1))
171         return self.__data
172     def __str__(self):
173         return 'Commit<sha1: %s, data: %s>' % (self.sha1, self.__data)
174
175 class Refs(object):
176     def __init__(self, repository):
177         self.__repository = repository
178         self.__refs = None
179     def __cache_refs(self):
180         self.__refs = {}
181         for line in self.__repository.run(['git', 'show-ref']).output_lines():
182             m = re.match(r'^([0-9a-f]{40})\s+(\S+)$', line)
183             sha1, ref = m.groups()
184             self.__refs[ref] = sha1
185     def get(self, ref):
186         """Throws KeyError if ref doesn't exist."""
187         if self.__refs == None:
188             self.__cache_refs()
189         return self.__repository.get_commit(self.__refs[ref])
190     def exists(self, ref):
191         try:
192             self.get(ref)
193         except KeyError:
194             return False
195         else:
196             return True
197     def set(self, ref, commit, msg):
198         if self.__refs == None:
199             self.__cache_refs()
200         old_sha1 = self.__refs.get(ref, '0'*40)
201         new_sha1 = commit.sha1
202         if old_sha1 != new_sha1:
203             self.__repository.run(['git', 'update-ref', '-m', msg,
204                                    ref, new_sha1, old_sha1]).no_output()
205             self.__refs[ref] = new_sha1
206     def delete(self, ref):
207         if self.__refs == None:
208             self.__cache_refs()
209         self.__repository.run(['git', 'update-ref',
210                                '-d', ref, self.__refs[ref]]).no_output()
211         del self.__refs[ref]
212
213 class ObjectCache(object):
214     """Cache for Python objects, for making sure that we create only one
215     Python object per git object."""
216     def __init__(self, create):
217         self.__objects = {}
218         self.__create = create
219     def __getitem__(self, name):
220         if not name in self.__objects:
221             self.__objects[name] = self.__create(name)
222         return self.__objects[name]
223     def __contains__(self, name):
224         return name in self.__objects
225     def __setitem__(self, name, val):
226         assert not name in self.__objects
227         self.__objects[name] = val
228
229 class RunWithEnv(object):
230     def run(self, args, env = {}):
231         return run.Run(*args).env(utils.add_dict(self.env, env))
232
233 class Repository(RunWithEnv):
234     def __init__(self, directory):
235         self.__git_dir = directory
236         self.__refs = Refs(self)
237         self.__trees = ObjectCache(lambda sha1: Tree(sha1))
238         self.__commits = ObjectCache(lambda sha1: Commit(self, sha1))
239         self.__default_index = None
240         self.__default_worktree = None
241         self.__default_iw = None
242     env = property(lambda self: { 'GIT_DIR': self.__git_dir })
243     @classmethod
244     def default(cls):
245         """Return the default repository."""
246         try:
247             return cls(run.Run('git', 'rev-parse', '--git-dir'
248                                ).output_one_line())
249         except run.RunException:
250             raise RepositoryException('Cannot find git repository')
251     @property
252     def default_index(self):
253         if self.__default_index == None:
254             self.__default_index = Index(
255                 self, (os.environ.get('GIT_INDEX_FILE', None)
256                        or os.path.join(self.__git_dir, 'index')))
257         return self.__default_index
258     def temp_index(self):
259         return Index(self, self.__git_dir)
260     @property
261     def default_worktree(self):
262         if self.__default_worktree == None:
263             path = os.environ.get('GIT_WORK_TREE', None)
264             if not path:
265                 o = run.Run('git', 'rev-parse', '--show-cdup').output_lines()
266                 o = o or ['.']
267                 assert len(o) == 1
268                 path = o[0]
269             self.__default_worktree = Worktree(path)
270         return self.__default_worktree
271     @property
272     def default_iw(self):
273         if self.__default_iw == None:
274             self.__default_iw = IndexAndWorktree(self.default_index,
275                                                  self.default_worktree)
276         return self.__default_iw
277     directory = property(lambda self: self.__git_dir)
278     refs = property(lambda self: self.__refs)
279     def cat_object(self, sha1):
280         return self.run(['git', 'cat-file', '-p', sha1]).raw_output()
281     def rev_parse(self, rev):
282         try:
283             return self.get_commit(self.run(
284                     ['git', 'rev-parse', '%s^{commit}' % rev]
285                     ).output_one_line())
286         except run.RunException:
287             raise RepositoryException('%s: No such revision' % rev)
288     def get_tree(self, sha1):
289         return self.__trees[sha1]
290     def get_commit(self, sha1):
291         return self.__commits[sha1]
292     def commit(self, commitdata):
293         c = ['git', 'commit-tree', commitdata.tree.sha1]
294         for p in commitdata.parents:
295             c.append('-p')
296             c.append(p.sha1)
297         env = {}
298         for p, v1 in ((commitdata.author, 'AUTHOR'),
299                        (commitdata.committer, 'COMMITTER')):
300             if p != None:
301                 for attr, v2 in (('name', 'NAME'), ('email', 'EMAIL'),
302                                  ('date', 'DATE')):
303                     if getattr(p, attr) != None:
304                         env['GIT_%s_%s' % (v1, v2)] = getattr(p, attr)
305         sha1 = self.run(c, env = env).raw_input(commitdata.message
306                                                 ).output_one_line()
307         return self.get_commit(sha1)
308     @property
309     def head(self):
310         try:
311             return self.run(['git', 'symbolic-ref', '-q', 'HEAD']
312                             ).output_one_line()
313         except run.RunException:
314             raise DetachedHeadException()
315     def set_head(self, ref, msg):
316         self.run(['git', 'symbolic-ref', '-m', msg, 'HEAD', ref]).no_output()
317     def simple_merge(self, base, ours, theirs):
318         """Given three trees, tries to do an in-index merge in a temporary
319         index with a temporary index. Returns the result tree, or None if
320         the merge failed (due to conflicts)."""
321         assert isinstance(base, Tree)
322         assert isinstance(ours, Tree)
323         assert isinstance(theirs, Tree)
324
325         # Take care of the really trivial cases.
326         if base == ours:
327             return theirs
328         if base == theirs:
329             return ours
330         if ours == theirs:
331             return ours
332
333         index = self.temp_index()
334         try:
335             index.merge(base, ours, theirs)
336             try:
337                 return index.write_tree()
338             except MergeException:
339                 return None
340         finally:
341             index.delete()
342
343 class MergeException(exception.StgException):
344     pass
345
346 class Index(RunWithEnv):
347     def __init__(self, repository, filename):
348         self.__repository = repository
349         if os.path.isdir(filename):
350             # Create a temp index in the given directory.
351             self.__filename = os.path.join(
352                 filename, 'index.temp-%d-%x' % (os.getpid(), id(self)))
353             self.delete()
354         else:
355             self.__filename = filename
356     env = property(lambda self: utils.add_dict(
357             self.__repository.env, { 'GIT_INDEX_FILE': self.__filename }))
358     def read_tree(self, tree):
359         self.run(['git', 'read-tree', tree.sha1]).no_output()
360     def write_tree(self):
361         try:
362             return self.__repository.get_tree(
363                 self.run(['git', 'write-tree']).discard_stderr(
364                     ).output_one_line())
365         except run.RunException:
366             raise MergeException('Conflicting merge')
367     def is_clean(self):
368         try:
369             self.run(['git', 'update-index', '--refresh']).discard_output()
370         except run.RunException:
371             return False
372         else:
373             return True
374     def merge(self, base, ours, theirs):
375         """In-index merge, no worktree involved."""
376         self.run(['git', 'read-tree', '-m', '-i', '--aggressive',
377                   base.sha1, ours.sha1, theirs.sha1]).no_output()
378     def delete(self):
379         if os.path.isfile(self.__filename):
380             os.remove(self.__filename)
381     def conflicts(self):
382         """The set of conflicting paths."""
383         paths = set()
384         for line in self.run(['git', 'ls-files', '-z', '--unmerged']
385                              ).raw_output().split('\0')[:-1]:
386             stat, path = line.split('\t', 1)
387             paths.add(path)
388         return paths
389
390 class Worktree(object):
391     def __init__(self, directory):
392         self.__directory = directory
393     env = property(lambda self: { 'GIT_WORK_TREE': self.__directory })
394     directory = property(lambda self: self.__directory)
395
396 class CheckoutException(exception.StgException):
397     pass
398
399 class IndexAndWorktree(RunWithEnv):
400     def __init__(self, index, worktree):
401         self.__index = index
402         self.__worktree = worktree
403     index = property(lambda self: self.__index)
404     env = property(lambda self: utils.add_dict(self.__index.env,
405                                                self.__worktree.env))
406     def checkout(self, old_tree, new_tree):
407         # TODO: Optionally do a 3-way instead of doing nothing when we
408         # have a problem. Or maybe we should stash changes in a patch?
409         assert isinstance(old_tree, Tree)
410         assert isinstance(new_tree, Tree)
411         try:
412             self.run(['git', 'read-tree', '-u', '-m',
413                       '--exclude-per-directory=.gitignore',
414                       old_tree.sha1, new_tree.sha1]
415                      ).cwd(self.__worktree.directory).discard_output()
416         except run.RunException:
417             raise CheckoutException('Index/workdir dirty')
418     def merge(self, base, ours, theirs):
419         assert isinstance(base, Tree)
420         assert isinstance(ours, Tree)
421         assert isinstance(theirs, Tree)
422         try:
423             self.run(['git', 'merge-recursive', base.sha1, '--', ours.sha1,
424                       theirs.sha1],
425                      env = { 'GITHEAD_%s' % base.sha1: 'ancestor',
426                              'GITHEAD_%s' % ours.sha1: 'current',
427                              'GITHEAD_%s' % theirs.sha1: 'patched'}
428                      ).cwd(self.__worktree.directory).discard_output()
429         except run.RunException, e:
430             raise MergeException('Index/worktree dirty')
431     def changed_files(self):
432         return self.run(['git', 'diff-files', '--name-only']).output_lines()
433     def update_index(self, files):
434         self.run(['git', 'update-index', '--remove', '-z', '--stdin']
435                  ).input_nulterm(files).discard_output()