chiark / gitweb /
72bf889206bdaf5fc80100c76caf2a179edb3687
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, popen2, re, gitmergeonefile
22 from shutil import copyfile
23
24 from stgit import basedir
25 from stgit.utils import *
26 from stgit.config import config
27 from sets import Set
28
29 # git exception class
30 class GitException(Exception):
31     pass
32
33
34
35 #
36 # Classes
37 #
38
39 class Person:
40     """An author, committer, etc."""
41     def __init__(self, name = None, email = None, date = '',
42                  desc = None):
43         self.name = self.email = self.date = None
44         if name or email or date:
45             assert not desc
46             self.name = name
47             self.email = email
48             self.date = date
49         elif desc:
50             assert not (name or email or date)
51             def parse_desc(s):
52                 m = re.match(r'^(.+)<(.+)>(.*)$', s)
53                 assert m
54                 return [x.strip() or None for x in m.groups()]
55             self.name, self.email, self.date = parse_desc(desc)
56     def set_name(self, val):
57         if val:
58             self.name = val
59     def set_email(self, val):
60         if val:
61             self.email = val
62     def set_date(self, val):
63         if val:
64             self.date = val
65     def __str__(self):
66         if self.name and self.email:
67             return '%s <%s>' % (self.name, self.email)
68         else:
69             raise GitException, 'not enough identity data'
70
71 class Commit:
72     """Handle the commit objects
73     """
74     def __init__(self, id_hash):
75         self.__id_hash = id_hash
76
77         lines = _output_lines(['git-cat-file', 'commit', id_hash])
78         for i in range(len(lines)):
79             line = lines[i]
80             if line == '\n':
81                 break
82             field = line.strip().split(' ', 1)
83             if field[0] == 'tree':
84                 self.__tree = field[1]
85             if field[0] == 'author':
86                 self.__author = field[1]
87             if field[0] == 'committer':
88                 self.__committer = field[1]
89         self.__log = ''.join(lines[i+1:])
90
91     def get_id_hash(self):
92         return self.__id_hash
93
94     def get_tree(self):
95         return self.__tree
96
97     def get_parent(self):
98         parents = self.get_parents()
99         if parents:
100             return parents[0]
101         else:
102             return None
103
104     def get_parents(self):
105         return _output_lines(['git-rev-list', '--parents', '--max-count=1',
106                               self.__id_hash])[0].split()[1:]
107
108     def get_author(self):
109         return self.__author
110
111     def get_committer(self):
112         return self.__committer
113
114     def get_log(self):
115         return self.__log
116
117     def __str__(self):
118         return self.get_id_hash()
119
120 # dictionary of Commit objects, used to avoid multiple calls to git
121 __commits = dict()
122
123 #
124 # Functions
125 #
126
127 def get_commit(id_hash):
128     """Commit objects factory. Save/look-up them in the __commits
129     dictionary
130     """
131     global __commits
132
133     if id_hash in __commits:
134         return __commits[id_hash]
135     else:
136         commit = Commit(id_hash)
137         __commits[id_hash] = commit
138         return commit
139
140 def get_conflicts():
141     """Return the list of file conflicts
142     """
143     conflicts_file = os.path.join(basedir.get(), 'conflicts')
144     if os.path.isfile(conflicts_file):
145         f = file(conflicts_file)
146         names = [line.strip() for line in f.readlines()]
147         f.close()
148         return names
149     else:
150         return None
151
152 def _input(cmd, file_desc):
153     p = popen2.Popen3(cmd, True)
154     while True:
155         line = file_desc.readline()
156         if not line:
157             break
158         p.tochild.write(line)
159     p.tochild.close()
160     if p.wait():
161         raise GitException, '%s failed (%s)' % (' '.join(cmd),
162                                                 p.childerr.read().strip())
163
164 def _input_str(cmd, string):
165     p = popen2.Popen3(cmd, True)
166     p.tochild.write(string)
167     p.tochild.close()
168     if p.wait():
169         raise GitException, '%s failed (%s)' % (' '.join(cmd),
170                                                 p.childerr.read().strip())
171
172 def _output(cmd):
173     p=popen2.Popen3(cmd, True)
174     output = p.fromchild.read()
175     if p.wait():
176         raise GitException, '%s failed (%s)' % (' '.join(cmd),
177                                                 p.childerr.read().strip())
178     return output
179
180 def _output_one_line(cmd, file_desc = None):
181     p=popen2.Popen3(cmd, True)
182     if file_desc != None:
183         for line in file_desc:
184             p.tochild.write(line)
185         p.tochild.close()
186     output = p.fromchild.readline().strip()
187     if p.wait():
188         raise GitException, '%s failed (%s)' % (' '.join(cmd),
189                                                 p.childerr.read().strip())
190     return output
191
192 def _output_lines(cmd):
193     p=popen2.Popen3(cmd, True)
194     lines = p.fromchild.readlines()
195     if p.wait():
196         raise GitException, '%s failed (%s)' % (' '.join(cmd),
197                                                 p.childerr.read().strip())
198     return lines
199
200 def __run(cmd, args=None):
201     """__run: runs cmd using spawnvp.
202
203     Runs cmd using spawnvp.  The shell is avoided so it won't mess up
204     our arguments.  If args is very large, the command is run multiple
205     times; args is split xargs style: cmd is passed on each
206     invocation.  Unlike xargs, returns immediately if any non-zero
207     return code is received.  
208     """
209     
210     args_l=cmd.split()
211     if args is None:
212         args = []
213     for i in range(0, len(args)+1, 100):
214         r=os.spawnvp(os.P_WAIT, args_l[0], args_l + args[i:min(i+100, len(args))])
215     if r:
216         return r
217     return 0
218
219 def tree_status(files = None, tree_id = 'HEAD', unknown = False,
220                   noexclude = True, verbose = False, diff_flags = []):
221     """Returns a list of pairs - [status, filename]
222     """
223     if verbose:
224         out.start('Checking for changes in the working directory')
225
226     refresh_index()
227
228     if not files:
229         files = []
230     cache_files = []
231
232     # unknown files
233     if unknown:
234         exclude_file = os.path.join(basedir.get(), 'info', 'exclude')
235         base_exclude = ['--exclude=%s' % s for s in
236                         ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
237         base_exclude.append('--exclude-per-directory=.gitignore')
238
239         if os.path.exists(exclude_file):
240             extra_exclude = ['--exclude-from=%s' % exclude_file]
241         else:
242             extra_exclude = []
243         if noexclude:
244             extra_exclude = base_exclude = []
245
246         lines = _output_lines(['git-ls-files', '--others', '--directory']
247                         + base_exclude + extra_exclude)
248         cache_files += [('?', line.strip()) for line in lines]
249
250     # conflicted files
251     conflicts = get_conflicts()
252     if not conflicts:
253         conflicts = []
254     cache_files += [('C', filename) for filename in conflicts]
255
256     # the rest
257     for line in _output_lines(['git-diff-index'] + diff_flags +
258                               [ tree_id, '--'] + files):
259         fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
260         if fs[1] not in conflicts:
261             cache_files.append(fs)
262
263     if verbose:
264         out.done()
265
266     return cache_files
267
268 def local_changes(verbose = True):
269     """Return true if there are local changes in the tree
270     """
271     return len(tree_status(verbose = verbose)) != 0
272
273 # HEAD value cached
274 __head = None
275
276 def get_head():
277     """Verifies the HEAD and returns the SHA1 id that represents it
278     """
279     global __head
280
281     if not __head:
282         __head = rev_parse('HEAD')
283     return __head
284
285 def get_head_file():
286     """Returns the name of the file pointed to by the HEAD link
287     """
288     return strip_prefix('refs/heads/',
289                         _output_one_line(['git-symbolic-ref', 'HEAD']))
290
291 def set_head_file(ref):
292     """Resets HEAD to point to a new ref
293     """
294     # head cache flushing is needed since we might have a different value
295     # in the new head
296     __clear_head_cache()
297     if __run('git-symbolic-ref HEAD',
298              [os.path.join('refs', 'heads', ref)]) != 0:
299         raise GitException, 'Could not set head to "%s"' % ref
300
301 def set_branch(branch, val):
302     """Point branch at a new commit object."""
303     if __run('git-update-ref', [branch, val]) != 0:
304         raise GitException, 'Could not update %s to "%s".' % (branch, val)
305
306 def __set_head(val):
307     """Sets the HEAD value
308     """
309     global __head
310
311     if not __head or __head != val:
312         set_branch('HEAD', val)
313         __head = val
314
315     # only allow SHA1 hashes
316     assert(len(__head) == 40)
317
318 def __clear_head_cache():
319     """Sets the __head to None so that a re-read is forced
320     """
321     global __head
322
323     __head = None
324
325 def refresh_index():
326     """Refresh index with stat() information from the working directory.
327     """
328     __run('git-update-index -q --unmerged --refresh')
329
330 def rev_parse(git_id):
331     """Parse the string and return a verified SHA1 id
332     """
333     try:
334         return _output_one_line(['git-rev-parse', '--verify', git_id])
335     except GitException:
336         raise GitException, 'Unknown revision: %s' % git_id
337
338 def branch_exists(branch):
339     """Existence check for the named branch
340     """
341     branch = os.path.join('refs', 'heads', branch)
342     for line in _output_lines('git-rev-parse --symbolic --all 2>&1'):
343         if line.strip() == branch:
344             return True
345         if re.compile('[ |/]'+branch+' ').search(line):
346             raise GitException, 'Bogus branch: %s' % line
347     return False
348
349 def create_branch(new_branch, tree_id = None):
350     """Create a new branch in the git repository
351     """
352     if branch_exists(new_branch):
353         raise GitException, 'Branch "%s" already exists' % new_branch
354
355     current_head = get_head()
356     set_head_file(new_branch)
357     __set_head(current_head)
358
359     # a checkout isn't needed if new branch points to the current head
360     if tree_id:
361         switch(tree_id)
362
363     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
364         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
365
366 def switch_branch(new_branch):
367     """Switch to a git branch
368     """
369     global __head
370
371     if not branch_exists(new_branch):
372         raise GitException, 'Branch "%s" does not exist' % new_branch
373
374     tree_id = rev_parse(os.path.join('refs', 'heads', new_branch)
375                         + '^{commit}')
376     if tree_id != get_head():
377         refresh_index()
378         if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
379             raise GitException, 'git-read-tree failed (local changes maybe?)'
380         __head = tree_id
381     set_head_file(new_branch)
382
383     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
384         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
385
386 def delete_branch(name):
387     """Delete a git branch
388     """
389     if not branch_exists(name):
390         raise GitException, 'Branch "%s" does not exist' % name
391     remove_file_and_dirs(os.path.join(basedir.get(), 'refs', 'heads'),
392                          name)
393
394 def rename_branch(from_name, to_name):
395     """Rename a git branch
396     """
397     if not branch_exists(from_name):
398         raise GitException, 'Branch "%s" does not exist' % from_name
399     if branch_exists(to_name):
400         raise GitException, 'Branch "%s" already exists' % to_name
401
402     if get_head_file() == from_name:
403         set_head_file(to_name)
404     rename(os.path.join(basedir.get(), 'refs', 'heads'),
405            from_name, to_name)
406
407     reflog_dir = os.path.join(basedir.get(), 'logs', 'refs', 'heads')
408     if os.path.exists(reflog_dir) \
409            and os.path.exists(os.path.join(reflog_dir, from_name)):
410         rename(reflog_dir, from_name, to_name)
411
412 def add(names):
413     """Add the files or recursively add the directory contents
414     """
415     # generate the file list
416     files = []
417     for i in names:
418         if not os.path.exists(i):
419             raise GitException, 'Unknown file or directory: %s' % i
420
421         if os.path.isdir(i):
422             # recursive search. We only add files
423             for root, dirs, local_files in os.walk(i):
424                 for name in [os.path.join(root, f) for f in local_files]:
425                     if os.path.isfile(name):
426                         files.append(os.path.normpath(name))
427         elif os.path.isfile(i):
428             files.append(os.path.normpath(i))
429         else:
430             raise GitException, '%s is not a file or directory' % i
431
432     if files:
433         if __run('git-update-index --add --', files):
434             raise GitException, 'Unable to add file'
435
436 def __copy_single(source, target, target2=''):
437     """Copy file or dir named 'source' to name target+target2"""
438
439     # "source" (file or dir) must match one or more git-controlled file
440     realfiles = _output_lines(['git-ls-files', source])
441     if len(realfiles) == 0:
442         raise GitException, '"%s" matches no git-controled files' % source
443
444     if os.path.isdir(source):
445         # physically copy the files, and record them to add them in one run
446         newfiles = []
447         re_string='^'+source+'/(.*)$'
448         prefix_regexp = re.compile(re_string)
449         for f in [f.strip() for f in realfiles]:
450             m = prefix_regexp.match(f)
451             if not m:
452                 raise Exception, '"%s" does not match "%s"' % (f, re_string)
453             newname = target+target2+'/'+m.group(1)
454             if not os.path.exists(os.path.dirname(newname)):
455                 os.makedirs(os.path.dirname(newname))
456             copyfile(f, newname)
457             newfiles.append(newname)
458
459         add(newfiles)
460     else: # files, symlinks, ...
461         newname = target+target2
462         copyfile(source, newname)
463         add([newname])
464
465
466 def copy(filespecs, target):
467     if os.path.isdir(target):
468         # target is a directory: copy each entry on the command line,
469         # with the same name, into the target
470         target = target.rstrip('/')
471         
472         # first, check that none of the children of the target
473         # matching the command line aleady exist
474         for filespec in filespecs:
475             entry = target+ '/' + os.path.basename(filespec.rstrip('/'))
476             if os.path.exists(entry):
477                 raise GitException, 'Target "%s" already exists' % entry
478         
479         for filespec in filespecs:
480             filespec = filespec.rstrip('/')
481             basename = '/' + os.path.basename(filespec)
482             __copy_single(filespec, target, basename)
483
484     elif os.path.exists(target):
485         raise GitException, 'Target "%s" exists but is not a directory' % target
486     elif len(filespecs) != 1:
487         raise GitException, 'Cannot copy more than one file to non-directory'
488
489     else:
490         # at this point: len(filespecs)==1 and target does not exist
491
492         # check target directory
493         targetdir = os.path.dirname(target)
494         if targetdir != '' and not os.path.isdir(targetdir):
495             raise GitException, 'Target directory "%s" does not exist' % targetdir
496
497         __copy_single(filespecs[0].rstrip('/'), target)
498         
499
500 def rm(files, force = False):
501     """Remove a file from the repository
502     """
503     if not force:
504         for f in files:
505             if os.path.exists(f):
506                 raise GitException, '%s exists. Remove it first' %f
507         if files:
508             __run('git-update-index --remove --', files)
509     else:
510         if files:
511             __run('git-update-index --force-remove --', files)
512
513 # Persons caching
514 __user = None
515 __author = None
516 __committer = None
517
518 def user():
519     """Return the user information.
520     """
521     global __user
522     if not __user:
523         name=config.get('user.name')
524         email=config.get('user.email')
525         __user = Person(name, email)
526     return __user;
527
528 def author():
529     """Return the author information.
530     """
531     global __author
532     if not __author:
533         try:
534             # the environment variables take priority over config
535             try:
536                 date = os.environ['GIT_AUTHOR_DATE']
537             except KeyError:
538                 date = ''
539             __author = Person(os.environ['GIT_AUTHOR_NAME'],
540                               os.environ['GIT_AUTHOR_EMAIL'],
541                               date)
542         except KeyError:
543             __author = user()
544     return __author
545
546 def committer():
547     """Return the author information.
548     """
549     global __committer
550     if not __committer:
551         try:
552             # the environment variables take priority over config
553             try:
554                 date = os.environ['GIT_COMMITTER_DATE']
555             except KeyError:
556                 date = ''
557             __committer = Person(os.environ['GIT_COMMITTER_NAME'],
558                                  os.environ['GIT_COMMITTER_EMAIL'],
559                                  date)
560         except KeyError:
561             __committer = user()
562     return __committer
563
564 def update_cache(files = None, force = False):
565     """Update the cache information for the given files
566     """
567     if not files:
568         files = []
569
570     cache_files = tree_status(files, verbose = False)
571
572     # everything is up-to-date
573     if len(cache_files) == 0:
574         return False
575
576     # check for unresolved conflicts
577     if not force and [x for x in cache_files
578                       if x[0] not in ['M', 'N', 'A', 'D']]:
579         raise GitException, 'Updating cache failed: unresolved conflicts'
580
581     # update the cache
582     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
583     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
584     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
585
586     if add_files and __run('git-update-index --add --', add_files) != 0:
587         raise GitException, 'Failed git-update-index --add'
588     if rm_files and __run('git-update-index --force-remove --', rm_files) != 0:
589         raise GitException, 'Failed git-update-index --rm'
590     if m_files and __run('git-update-index --', m_files) != 0:
591         raise GitException, 'Failed git-update-index'
592
593     return True
594
595 def commit(message, files = None, parents = None, allowempty = False,
596            cache_update = True, tree_id = None, set_head = False,
597            author_name = None, author_email = None, author_date = None,
598            committer_name = None, committer_email = None):
599     """Commit the current tree to repository
600     """
601     if not files:
602         files = []
603     if not parents:
604         parents = []
605
606     # Get the tree status
607     if cache_update and parents != []:
608         changes = update_cache(files)
609         if not changes and not allowempty:
610             raise GitException, 'No changes to commit'
611
612     # get the commit message
613     if not message:
614         message = '\n'
615     elif message[-1:] != '\n':
616         message += '\n'
617
618     # write the index to repository
619     if tree_id == None:
620         tree_id = _output_one_line(['git-write-tree'])
621         set_head = True
622
623     # the commit
624     cmd = ['env']
625     if author_name:
626         cmd += ['GIT_AUTHOR_NAME=%s' % author_name]
627     if author_email:
628         cmd += ['GIT_AUTHOR_EMAIL=%s' % author_email]
629     if author_date:
630         cmd += ['GIT_AUTHOR_DATE=%s' % author_date]
631     if committer_name:
632         cmd += ['GIT_COMMITTER_NAME=%s' % committer_name]
633     if committer_email:
634         cmd += ['GIT_COMMITTER_EMAIL=%s' % committer_email]
635     cmd += ['git-commit-tree', tree_id]
636
637     # get the parents
638     for p in parents:
639         cmd += ['-p', p]
640
641     commit_id = _output_one_line(cmd, message)
642     if set_head:
643         __set_head(commit_id)
644
645     return commit_id
646
647 def apply_diff(rev1, rev2, check_index = True, files = None):
648     """Apply the diff between rev1 and rev2 onto the current
649     index. This function doesn't need to raise an exception since it
650     is only used for fast-pushing a patch. If this operation fails,
651     the pushing would fall back to the three-way merge.
652     """
653     if check_index:
654         index_opt = ['--index']
655     else:
656         index_opt = []
657
658     if not files:
659         files = []
660
661     diff_str = diff(files, rev1, rev2)
662     if diff_str:
663         try:
664             _input_str(['git-apply'] + index_opt, diff_str)
665         except GitException:
666             return False
667
668     return True
669
670 def merge(base, head1, head2, recursive = False):
671     """Perform a 3-way merge between base, head1 and head2 into the
672     local tree
673     """
674     refresh_index()
675
676     err_output = None
677     if recursive:
678         # this operation tracks renames but it is slower (used in
679         # general when pushing or picking patches)
680         try:
681             # use _output() to mask the verbose prints of the tool
682             _output(['git-merge-recursive', base, '--', head1, head2])
683         except GitException, ex:
684             err_output = str(ex)
685             pass
686     else:
687         # the fast case where we don't track renames (used when the
688         # distance between base and heads is small, i.e. folding or
689         # synchronising patches)
690         if __run('git-read-tree -u -m --aggressive',
691                  [base, head1, head2]) != 0:
692             raise GitException, 'git-read-tree failed (local changes maybe?)'
693
694     # check the index for unmerged entries
695     files = {}
696     stages_re = re.compile('^([0-7]+) ([0-9a-f]{40}) ([1-3])\t(.*)$', re.S)
697
698     for line in _output(['git-ls-files', '--unmerged', '--stage', '-z']).split('\0'):
699         if not line:
700             continue
701
702         mode, hash, stage, path = stages_re.findall(line)[0]
703
704         if not path in files:
705             files[path] = {}
706             files[path]['1'] = ('', '')
707             files[path]['2'] = ('', '')
708             files[path]['3'] = ('', '')
709
710         files[path][stage] = (mode, hash)
711
712     if err_output and not files:
713         # if no unmerged files, there was probably a different type of
714         # error and we have to abort the merge
715         raise GitException, err_output
716
717     # merge the unmerged files
718     errors = False
719     for path in files:
720         # remove additional files that might be generated for some
721         # newer versions of GIT
722         for suffix in [base, head1, head2]:
723             if not suffix:
724                 continue
725             fname = path + '~' + suffix
726             if os.path.exists(fname):
727                 os.remove(fname)
728
729         stages = files[path]
730         if gitmergeonefile.merge(stages['1'][1], stages['2'][1],
731                                  stages['3'][1], path, stages['1'][0],
732                                  stages['2'][0], stages['3'][0]) != 0:
733             errors = True
734
735     if errors:
736         raise GitException, 'GIT index merging failed (possible conflicts)'
737
738 def status(files = None, modified = False, new = False, deleted = False,
739            conflict = False, unknown = False, noexclude = False,
740            diff_flags = []):
741     """Show the tree status
742     """
743     if not files:
744         files = []
745
746     cache_files = tree_status(files, unknown = True, noexclude = noexclude,
747                                 diff_flags = diff_flags)
748     all = not (modified or new or deleted or conflict or unknown)
749
750     if not all:
751         filestat = []
752         if modified:
753             filestat.append('M')
754         if new:
755             filestat.append('A')
756             filestat.append('N')
757         if deleted:
758             filestat.append('D')
759         if conflict:
760             filestat.append('C')
761         if unknown:
762             filestat.append('?')
763         cache_files = [x for x in cache_files if x[0] in filestat]
764
765     for fs in cache_files:
766         if files and not fs[1] in files:
767             continue
768         if all:
769             out.stdout('%s %s' % (fs[0], fs[1]))
770         else:
771             out.stdout('%s' % fs[1])
772
773 def diff(files = None, rev1 = 'HEAD', rev2 = None, out_fd = None,
774          diff_flags = []):
775     """Show the diff between rev1 and rev2
776     """
777     if not files:
778         files = []
779
780     if rev1 and rev2:
781         diff_str = _output(['git-diff-tree', '-p'] + diff_flags
782                            + [rev1, rev2, '--'] + files)
783     elif rev1 or rev2:
784         refresh_index()
785         if rev2:
786             diff_str = _output(['git-diff-index', '-p', '-R']
787                                + diff_flags + [rev2, '--'] + files)
788         else:
789             diff_str = _output(['git-diff-index', '-p']
790                                + diff_flags + [rev1, '--'] + files)
791     else:
792         diff_str = ''
793
794     if out_fd:
795         out_fd.write(diff_str)
796     else:
797         return diff_str
798
799 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
800     """Return the diffstat between rev1 and rev2
801     """
802     if not files:
803         files = []
804
805     p=popen2.Popen3('git-apply --stat')
806     diff(files, rev1, rev2, p.tochild)
807     p.tochild.close()
808     diff_str = p.fromchild.read().rstrip()
809     if p.wait():
810         raise GitException, 'git.diffstat failed'
811     return diff_str
812
813 def files(rev1, rev2, diff_flags = []):
814     """Return the files modified between rev1 and rev2
815     """
816
817     result = ''
818     for line in _output_lines(['git-diff-tree'] + diff_flags + ['-r', rev1, rev2]):
819         result += '%s %s\n' % tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
820
821     return result.rstrip()
822
823 def barefiles(rev1, rev2):
824     """Return the files modified between rev1 and rev2, without status info
825     """
826
827     result = ''
828     for line in _output_lines(['git-diff-tree', '-r', rev1, rev2]):
829         result += '%s\n' % line.rstrip().split(' ',4)[-1].split('\t',1)[-1]
830
831     return result.rstrip()
832
833 def pretty_commit(commit_id = 'HEAD', diff_flags = []):
834     """Return a given commit (log + diff)
835     """
836     return _output(['git-diff-tree'] + diff_flags +
837                    ['--cc', '--always', '--pretty', '-r', commit_id])
838
839 def checkout(files = None, tree_id = None, force = False):
840     """Check out the given or all files
841     """
842     if not files:
843         files = []
844
845     if tree_id and __run('git-read-tree --reset', [tree_id]) != 0:
846         raise GitException, 'Failed git-read-tree --reset %s' % tree_id
847
848     checkout_cmd = 'git-checkout-index -q -u'
849     if force:
850         checkout_cmd += ' -f'
851     if len(files) == 0:
852         checkout_cmd += ' -a'
853     else:
854         checkout_cmd += ' --'
855
856     if __run(checkout_cmd, files) != 0:
857         raise GitException, 'Failed git-checkout-index'
858
859 def switch(tree_id, keep = False):
860     """Switch the tree to the given id
861     """
862     if not keep:
863         refresh_index()
864         if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
865             raise GitException, 'git-read-tree failed (local changes maybe?)'
866
867     __set_head(tree_id)
868
869 def reset(files = None, tree_id = None, check_out = True):
870     """Revert the tree changes relative to the given tree_id. It removes
871     any local changes
872     """
873     if not tree_id:
874         tree_id = get_head()
875
876     if check_out:
877         cache_files = tree_status(files, tree_id)
878         # files which were added but need to be removed
879         rm_files =  [x[1] for x in cache_files if x[0] in ['A']]
880
881         checkout(files, tree_id, True)
882         # checkout doesn't remove files
883         map(os.remove, rm_files)
884
885     # if the reset refers to the whole tree, switch the HEAD as well
886     if not files:
887         __set_head(tree_id)
888
889 def fetch(repository = 'origin', refspec = None):
890     """Fetches changes from the remote repository, using 'git-fetch'
891     by default.
892     """
893     # we update the HEAD
894     __clear_head_cache()
895
896     args = [repository]
897     if refspec:
898         args.append(refspec)
899
900     command = config.get('branch.%s.stgit.fetchcmd' % get_head_file()) or \
901               config.get('stgit.fetchcmd')
902     if __run(command, args) != 0:
903         raise GitException, 'Failed "%s %s"' % (command, repository)
904
905 def pull(repository = 'origin', refspec = None):
906     """Fetches changes from the remote repository, using 'git-pull'
907     by default.
908     """
909     # we update the HEAD
910     __clear_head_cache()
911
912     args = [repository]
913     if refspec:
914         args.append(refspec)
915
916     command = config.get('branch.%s.stgit.pullcmd' % get_head_file()) or \
917               config.get('stgit.pullcmd')
918     if __run(command, args) != 0:
919         raise GitException, 'Failed "%s %s"' % (command, repository)
920
921 def repack():
922     """Repack all objects into a single pack
923     """
924     __run('git-repack -a -d -f')
925
926 def apply_patch(filename = None, diff = None, base = None,
927                 fail_dump = True):
928     """Apply a patch onto the current or given index. There must not
929     be any local changes in the tree, otherwise the command fails
930     """
931     if diff is None:
932         if filename:
933             f = file(filename)
934         else:
935             f = sys.stdin
936         diff = f.read()
937         if filename:
938             f.close()
939
940     if base:
941         orig_head = get_head()
942         switch(base)
943     else:
944         refresh_index()
945
946     try:
947         _input_str(['git-apply', '--index'], diff)
948     except GitException:
949         if base:
950             switch(orig_head)
951         if fail_dump:
952             # write the failed diff to a file
953             f = file('.stgit-failed.patch', 'w+')
954             f.write(diff)
955             f.close()
956             out.warn('Diff written to the .stgit-failed.patch file')
957
958         raise
959
960     if base:
961         top = commit(message = 'temporary commit used for applying a patch',
962                      parents = [base])
963         switch(orig_head)
964         merge(base, orig_head, top)
965
966 def clone(repository, local_dir):
967     """Clone a remote repository. At the moment, just use the
968     'git-clone' script
969     """
970     if __run('git-clone', [repository, local_dir]) != 0:
971         raise GitException, 'Failed "git-clone %s %s"' \
972               % (repository, local_dir)
973
974 def modifying_revs(files, base_rev, head_rev):
975     """Return the revisions from the list modifying the given files
976     """
977     cmd = ['git-rev-list', '%s..%s' % (base_rev, head_rev), '--']
978     revs = [line.strip() for line in _output_lines(cmd + files)]
979
980     return revs
981
982
983 def refspec_localpart(refspec):
984     m = re.match('^[^:]*:([^:]*)$', refspec)
985     if m:
986         return m.group(1)
987     else:
988         raise GitException, 'Cannot parse refspec "%s"' % line
989
990 def refspec_remotepart(refspec):
991     m = re.match('^([^:]*):[^:]*$', refspec)
992     if m:
993         return m.group(1)
994     else:
995         raise GitException, 'Cannot parse refspec "%s"' % line
996     
997
998 def __remotes_from_config():
999     return config.sections_matching(r'remote\.(.*)\.url')
1000
1001 def __remotes_from_dir(dir):
1002     d = os.path.join(basedir.get(), dir)
1003     if os.path.exists(d):
1004         return os.listdir(d)
1005     else:
1006         return None
1007
1008 def remotes_list():
1009     """Return the list of remotes in the repository
1010     """
1011
1012     return Set(__remotes_from_config()) | \
1013            Set(__remotes_from_dir('remotes')) | \
1014            Set(__remotes_from_dir('branches'))
1015
1016 def remotes_local_branches(remote):
1017     """Returns the list of local branches fetched from given remote
1018     """
1019
1020     branches = []
1021     if remote in __remotes_from_config():
1022         for line in config.getall('remote.%s.fetch' % remote):
1023             branches.append(refspec_localpart(line))
1024     elif remote in __remotes_from_dir('remotes'):
1025         stream = open(os.path.join(basedir.get(), 'remotes', remote), 'r')
1026         for line in stream:
1027             # Only consider Pull lines
1028             m = re.match('^Pull: (.*)\n$', line)
1029             if m:
1030                 branches.append(refspec_localpart(m.group(1)))
1031         stream.close()
1032     elif remote in __remotes_from_dir('branches'):
1033         # old-style branches only declare one branch
1034         branches.append('refs/heads/'+remote);
1035     else:
1036         raise GitException, 'Unknown remote "%s"' % remote
1037
1038     return branches
1039
1040 def identify_remote(branchname):
1041     """Return the name for the remote to pull the given branchname
1042     from, or None if we believe it is a local branch.
1043     """
1044
1045     for remote in remotes_list():
1046         if branchname in remotes_local_branches(remote):
1047             return remote
1048
1049     # if we get here we've found nothing, the branch is a local one
1050     return None
1051
1052 def fetch_head():
1053     """Return the git id for the tip of the parent branch as left by
1054     'git fetch'.
1055     """
1056
1057     fetch_head=None
1058     stream = open(os.path.join(basedir.get(), 'FETCH_HEAD'), "r")
1059     for line in stream:
1060         # Only consider lines not tagged not-for-merge
1061         m = re.match('^([^\t]*)\t\t', line)
1062         if m:
1063             if fetch_head:
1064                 raise GitException, "StGit does not support multiple FETCH_HEAD"
1065             else:
1066                 fetch_head=m.group(1)
1067     stream.close()
1068
1069     # here we are sure to have a single fetch_head
1070     return fetch_head
1071
1072 def all_refs():
1073     """Return a list of all refs in the current repository.
1074     """
1075
1076     return [line.split()[1] for line in _output_lines(['git-show-ref'])]