chiark / gitweb /
7358fae6ef31d11f3f5f912e142b0488b988f21c
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, popen2, re, gitmergeonefile
22 from shutil import copyfile
23
24 from stgit import basedir
25 from stgit.utils import *
26 from stgit.config import config
27 from sets import Set
28
29 # git exception class
30 class GitException(Exception):
31     pass
32
33
34
35 #
36 # Classes
37 #
38
39 class Person:
40     """An author, committer, etc."""
41     def __init__(self, name = None, email = None, date = '',
42                  desc = None):
43         self.name = self.email = self.date = None
44         if name or email or date:
45             assert not desc
46             self.name = name
47             self.email = email
48             self.date = date
49         elif desc:
50             assert not (name or email or date)
51             def parse_desc(s):
52                 m = re.match(r'^(.+)<(.+)>(.*)$', s)
53                 assert m
54                 return [x.strip() or None for x in m.groups()]
55             self.name, self.email, self.date = parse_desc(desc)
56     def set_name(self, val):
57         if val:
58             self.name = val
59     def set_email(self, val):
60         if val:
61             self.email = val
62     def set_date(self, val):
63         if val:
64             self.date = val
65     def __str__(self):
66         if self.name and self.email:
67             return '%s <%s>' % (self.name, self.email)
68         else:
69             raise GitException, 'not enough identity data'
70
71 class Commit:
72     """Handle the commit objects
73     """
74     def __init__(self, id_hash):
75         self.__id_hash = id_hash
76
77         lines = _output_lines(['git-cat-file', 'commit', id_hash])
78         for i in range(len(lines)):
79             line = lines[i]
80             if line == '\n':
81                 break
82             field = line.strip().split(' ', 1)
83             if field[0] == 'tree':
84                 self.__tree = field[1]
85             if field[0] == 'author':
86                 self.__author = field[1]
87             if field[0] == 'committer':
88                 self.__committer = field[1]
89         self.__log = ''.join(lines[i+1:])
90
91     def get_id_hash(self):
92         return self.__id_hash
93
94     def get_tree(self):
95         return self.__tree
96
97     def get_parent(self):
98         parents = self.get_parents()
99         if parents:
100             return parents[0]
101         else:
102             return None
103
104     def get_parents(self):
105         return _output_lines(['git-rev-list', '--parents', '--max-count=1',
106                               self.__id_hash])[0].split()[1:]
107
108     def get_author(self):
109         return self.__author
110
111     def get_committer(self):
112         return self.__committer
113
114     def get_log(self):
115         return self.__log
116
117     def __str__(self):
118         return self.get_id_hash()
119
120 # dictionary of Commit objects, used to avoid multiple calls to git
121 __commits = dict()
122
123 #
124 # Functions
125 #
126
127 def get_commit(id_hash):
128     """Commit objects factory. Save/look-up them in the __commits
129     dictionary
130     """
131     global __commits
132
133     if id_hash in __commits:
134         return __commits[id_hash]
135     else:
136         commit = Commit(id_hash)
137         __commits[id_hash] = commit
138         return commit
139
140 def get_conflicts():
141     """Return the list of file conflicts
142     """
143     conflicts_file = os.path.join(basedir.get(), 'conflicts')
144     if os.path.isfile(conflicts_file):
145         f = file(conflicts_file)
146         names = [line.strip() for line in f.readlines()]
147         f.close()
148         return names
149     else:
150         return None
151
152 def _input(cmd, file_desc):
153     p = popen2.Popen3(cmd, True)
154     while True:
155         line = file_desc.readline()
156         if not line:
157             break
158         p.tochild.write(line)
159     p.tochild.close()
160     if p.wait():
161         raise GitException, '%s failed (%s)' % (str(cmd),
162                                                 p.childerr.read().strip())
163
164 def _input_str(cmd, string):
165     p = popen2.Popen3(cmd, True)
166     p.tochild.write(string)
167     p.tochild.close()
168     if p.wait():
169         raise GitException, '%s failed (%s)' % (str(cmd),
170                                                 p.childerr.read().strip())
171
172 def _output(cmd):
173     p=popen2.Popen3(cmd, True)
174     output = p.fromchild.read()
175     if p.wait():
176         raise GitException, '%s failed (%s)' % (str(cmd),
177                                                 p.childerr.read().strip())
178     return output
179
180 def _output_one_line(cmd, file_desc = None):
181     p=popen2.Popen3(cmd, True)
182     if file_desc != None:
183         for line in file_desc:
184             p.tochild.write(line)
185         p.tochild.close()
186     output = p.fromchild.readline().strip()
187     if p.wait():
188         raise GitException, '%s failed (%s)' % (str(cmd),
189                                                 p.childerr.read().strip())
190     return output
191
192 def _output_lines(cmd):
193     p=popen2.Popen3(cmd, True)
194     lines = p.fromchild.readlines()
195     if p.wait():
196         raise GitException, '%s failed (%s)' % (str(cmd),
197                                                 p.childerr.read().strip())
198     return lines
199
200 def __run(cmd, args=None):
201     """__run: runs cmd using spawnvp.
202
203     Runs cmd using spawnvp.  The shell is avoided so it won't mess up
204     our arguments.  If args is very large, the command is run multiple
205     times; args is split xargs style: cmd is passed on each
206     invocation.  Unlike xargs, returns immediately if any non-zero
207     return code is received.  
208     """
209     
210     args_l=cmd.split()
211     if args is None:
212         args = []
213     for i in range(0, len(args)+1, 100):
214         r=os.spawnvp(os.P_WAIT, args_l[0], args_l + args[i:min(i+100, len(args))])
215     if r:
216         return r
217     return 0
218
219 def __tree_status(files = None, tree_id = 'HEAD', unknown = False,
220                   noexclude = True, verbose = False):
221     """Returns a list of pairs - [status, filename]
222     """
223     if verbose:
224         out.start('Checking for changes in the working directory')
225
226     refresh_index()
227
228     if not files:
229         files = []
230     cache_files = []
231
232     # unknown files
233     if unknown:
234         exclude_file = os.path.join(basedir.get(), 'info', 'exclude')
235         base_exclude = ['--exclude=%s' % s for s in
236                         ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
237         base_exclude.append('--exclude-per-directory=.gitignore')
238
239         if os.path.exists(exclude_file):
240             extra_exclude = ['--exclude-from=%s' % exclude_file]
241         else:
242             extra_exclude = []
243         if noexclude:
244             extra_exclude = base_exclude = []
245
246         lines = _output_lines(['git-ls-files', '--others', '--directory']
247                         + base_exclude + extra_exclude)
248         cache_files += [('?', line.strip()) for line in lines]
249
250     # conflicted files
251     conflicts = get_conflicts()
252     if not conflicts:
253         conflicts = []
254     cache_files += [('C', filename) for filename in conflicts]
255
256     # the rest
257     for line in _output_lines(['git-diff-index', tree_id, '--'] + files):
258         fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
259         if fs[1] not in conflicts:
260             cache_files.append(fs)
261
262     if verbose:
263         out.done()
264
265     return cache_files
266
267 def local_changes(verbose = True):
268     """Return true if there are local changes in the tree
269     """
270     return len(__tree_status(verbose = verbose)) != 0
271
272 # HEAD value cached
273 __head = None
274
275 def get_head():
276     """Verifies the HEAD and returns the SHA1 id that represents it
277     """
278     global __head
279
280     if not __head:
281         __head = rev_parse('HEAD')
282     return __head
283
284 def get_head_file():
285     """Returns the name of the file pointed to by the HEAD link
286     """
287     return strip_prefix('refs/heads/',
288                         _output_one_line(['git-symbolic-ref', 'HEAD']))
289
290 def set_head_file(ref):
291     """Resets HEAD to point to a new ref
292     """
293     # head cache flushing is needed since we might have a different value
294     # in the new head
295     __clear_head_cache()
296     if __run('git-symbolic-ref HEAD',
297              [os.path.join('refs', 'heads', ref)]) != 0:
298         raise GitException, 'Could not set head to "%s"' % ref
299
300 def set_branch(branch, val):
301     """Point branch at a new commit object."""
302     if __run('git-update-ref', [branch, val]) != 0:
303         raise GitException, 'Could not update %s to "%s".' % (branch, val)
304
305 def __set_head(val):
306     """Sets the HEAD value
307     """
308     global __head
309
310     if not __head or __head != val:
311         set_branch('HEAD', val)
312         __head = val
313
314     # only allow SHA1 hashes
315     assert(len(__head) == 40)
316
317 def __clear_head_cache():
318     """Sets the __head to None so that a re-read is forced
319     """
320     global __head
321
322     __head = None
323
324 def refresh_index():
325     """Refresh index with stat() information from the working directory.
326     """
327     __run('git-update-index -q --unmerged --refresh')
328
329 def rev_parse(git_id):
330     """Parse the string and return a verified SHA1 id
331     """
332     try:
333         return _output_one_line(['git-rev-parse', '--verify', git_id])
334     except GitException:
335         raise GitException, 'Unknown revision: %s' % git_id
336
337 def branch_exists(branch):
338     """Existence check for the named branch
339     """
340     branch = os.path.join('refs', 'heads', branch)
341     for line in _output_lines('git-rev-parse --symbolic --all 2>&1'):
342         if line.strip() == branch:
343             return True
344         if re.compile('[ |/]'+branch+' ').search(line):
345             raise GitException, 'Bogus branch: %s' % line
346     return False
347
348 def create_branch(new_branch, tree_id = None):
349     """Create a new branch in the git repository
350     """
351     if branch_exists(new_branch):
352         raise GitException, 'Branch "%s" already exists' % new_branch
353
354     current_head = get_head()
355     set_head_file(new_branch)
356     __set_head(current_head)
357
358     # a checkout isn't needed if new branch points to the current head
359     if tree_id:
360         switch(tree_id)
361
362     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
363         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
364
365 def switch_branch(new_branch):
366     """Switch to a git branch
367     """
368     global __head
369
370     if not branch_exists(new_branch):
371         raise GitException, 'Branch "%s" does not exist' % new_branch
372
373     tree_id = rev_parse(os.path.join('refs', 'heads', new_branch)
374                         + '^{commit}')
375     if tree_id != get_head():
376         refresh_index()
377         if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
378             raise GitException, 'git-read-tree failed (local changes maybe?)'
379         __head = tree_id
380     set_head_file(new_branch)
381
382     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
383         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
384
385 def delete_branch(name):
386     """Delete a git branch
387     """
388     if not branch_exists(name):
389         raise GitException, 'Branch "%s" does not exist' % name
390     remove_file_and_dirs(os.path.join(basedir.get(), 'refs', 'heads'),
391                          name)
392
393 def rename_branch(from_name, to_name):
394     """Rename a git branch
395     """
396     if not branch_exists(from_name):
397         raise GitException, 'Branch "%s" does not exist' % from_name
398     if branch_exists(to_name):
399         raise GitException, 'Branch "%s" already exists' % to_name
400
401     if get_head_file() == from_name:
402         set_head_file(to_name)
403     rename(os.path.join(basedir.get(), 'refs', 'heads'),
404            from_name, to_name)
405
406     reflog_dir = os.path.join(basedir.get(), 'logs', 'refs', 'heads')
407     if os.path.exists(reflog_dir) \
408            and os.path.exists(os.path.join(reflog_dir, from_name)):
409         rename(reflog_dir, from_name, to_name)
410
411 def add(names):
412     """Add the files or recursively add the directory contents
413     """
414     # generate the file list
415     files = []
416     for i in names:
417         if not os.path.exists(i):
418             raise GitException, 'Unknown file or directory: %s' % i
419
420         if os.path.isdir(i):
421             # recursive search. We only add files
422             for root, dirs, local_files in os.walk(i):
423                 for name in [os.path.join(root, f) for f in local_files]:
424                     if os.path.isfile(name):
425                         files.append(os.path.normpath(name))
426         elif os.path.isfile(i):
427             files.append(os.path.normpath(i))
428         else:
429             raise GitException, '%s is not a file or directory' % i
430
431     if files:
432         if __run('git-update-index --add --', files):
433             raise GitException, 'Unable to add file'
434
435 def __copy_single(source, target, target2=''):
436     """Copy file or dir named 'source' to name target+target2"""
437
438     # "source" (file or dir) must match one or more git-controlled file
439     realfiles = _output_lines(['git-ls-files', source])
440     if len(realfiles) == 0:
441         raise GitException, '"%s" matches no git-controled files' % source
442
443     if os.path.isdir(source):
444         # physically copy the files, and record them to add them in one run
445         newfiles = []
446         re_string='^'+source+'/(.*)$'
447         prefix_regexp = re.compile(re_string)
448         for f in [f.strip() for f in realfiles]:
449             m = prefix_regexp.match(f)
450             if not m:
451                 raise Exception, '"%s" does not match "%s"' % (f, re_string)
452             newname = target+target2+'/'+m.group(1)
453             if not os.path.exists(os.path.dirname(newname)):
454                 os.makedirs(os.path.dirname(newname))
455             copyfile(f, newname)
456             newfiles.append(newname)
457
458         add(newfiles)
459     else: # files, symlinks, ...
460         newname = target+target2
461         copyfile(source, newname)
462         add([newname])
463
464
465 def copy(filespecs, target):
466     if os.path.isdir(target):
467         # target is a directory: copy each entry on the command line,
468         # with the same name, into the target
469         target = target.rstrip('/')
470         
471         # first, check that none of the children of the target
472         # matching the command line aleady exist
473         for filespec in filespecs:
474             entry = target+ '/' + os.path.basename(filespec.rstrip('/'))
475             if os.path.exists(entry):
476                 raise GitException, 'Target "%s" already exists' % entry
477         
478         for filespec in filespecs:
479             filespec = filespec.rstrip('/')
480             basename = '/' + os.path.basename(filespec)
481             __copy_single(filespec, target, basename)
482
483     elif os.path.exists(target):
484         raise GitException, 'Target "%s" exists but is not a directory' % target
485     elif len(filespecs) != 1:
486         raise GitException, 'Cannot copy more than one file to non-directory'
487
488     else:
489         # at this point: len(filespecs)==1 and target does not exist
490
491         # check target directory
492         targetdir = os.path.dirname(target)
493         if targetdir != '' and not os.path.isdir(targetdir):
494             raise GitException, 'Target directory "%s" does not exist' % targetdir
495
496         __copy_single(filespecs[0].rstrip('/'), target)
497         
498
499 def rm(files, force = False):
500     """Remove a file from the repository
501     """
502     if not force:
503         for f in files:
504             if os.path.exists(f):
505                 raise GitException, '%s exists. Remove it first' %f
506         if files:
507             __run('git-update-index --remove --', files)
508     else:
509         if files:
510             __run('git-update-index --force-remove --', files)
511
512 # Persons caching
513 __user = None
514 __author = None
515 __committer = None
516
517 def user():
518     """Return the user information.
519     """
520     global __user
521     if not __user:
522         name=config.get('user.name')
523         email=config.get('user.email')
524         __user = Person(name, email)
525     return __user;
526
527 def author():
528     """Return the author information.
529     """
530     global __author
531     if not __author:
532         try:
533             # the environment variables take priority over config
534             try:
535                 date = os.environ['GIT_AUTHOR_DATE']
536             except KeyError:
537                 date = ''
538             __author = Person(os.environ['GIT_AUTHOR_NAME'],
539                               os.environ['GIT_AUTHOR_EMAIL'],
540                               date)
541         except KeyError:
542             __author = user()
543     return __author
544
545 def committer():
546     """Return the author information.
547     """
548     global __committer
549     if not __committer:
550         try:
551             # the environment variables take priority over config
552             try:
553                 date = os.environ['GIT_COMMITTER_DATE']
554             except KeyError:
555                 date = ''
556             __committer = Person(os.environ['GIT_COMMITTER_NAME'],
557                                  os.environ['GIT_COMMITTER_EMAIL'],
558                                  date)
559         except KeyError:
560             __committer = user()
561     return __committer
562
563 def update_cache(files = None, force = False):
564     """Update the cache information for the given files
565     """
566     if not files:
567         files = []
568
569     cache_files = __tree_status(files, verbose = False)
570
571     # everything is up-to-date
572     if len(cache_files) == 0:
573         return False
574
575     # check for unresolved conflicts
576     if not force and [x for x in cache_files
577                       if x[0] not in ['M', 'N', 'A', 'D']]:
578         raise GitException, 'Updating cache failed: unresolved conflicts'
579
580     # update the cache
581     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
582     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
583     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
584
585     if add_files and __run('git-update-index --add --', add_files) != 0:
586         raise GitException, 'Failed git-update-index --add'
587     if rm_files and __run('git-update-index --force-remove --', rm_files) != 0:
588         raise GitException, 'Failed git-update-index --rm'
589     if m_files and __run('git-update-index --', m_files) != 0:
590         raise GitException, 'Failed git-update-index'
591
592     return True
593
594 def commit(message, files = None, parents = None, allowempty = False,
595            cache_update = True, tree_id = None,
596            author_name = None, author_email = None, author_date = None,
597            committer_name = None, committer_email = None):
598     """Commit the current tree to repository
599     """
600     if not files:
601         files = []
602     if not parents:
603         parents = []
604
605     # Get the tree status
606     if cache_update and parents != []:
607         changes = update_cache(files)
608         if not changes and not allowempty:
609             raise GitException, 'No changes to commit'
610
611     # get the commit message
612     if not message:
613         message = '\n'
614     elif message[-1:] != '\n':
615         message += '\n'
616
617     must_switch = True
618     # write the index to repository
619     if tree_id == None:
620         tree_id = _output_one_line(['git-write-tree'])
621     else:
622         must_switch = False
623
624     # the commit
625     cmd = ['env']
626     if author_name:
627         cmd += ['GIT_AUTHOR_NAME=%s' % author_name]
628     if author_email:
629         cmd += ['GIT_AUTHOR_EMAIL=%s' % author_email]
630     if author_date:
631         cmd += ['GIT_AUTHOR_DATE=%s' % author_date]
632     if committer_name:
633         cmd += ['GIT_COMMITTER_NAME=%s' % committer_name]
634     if committer_email:
635         cmd += ['GIT_COMMITTER_EMAIL=%s' % committer_email]
636     cmd += ['git-commit-tree', tree_id]
637
638     # get the parents
639     for p in parents:
640         cmd += ['-p', p]
641
642     commit_id = _output_one_line(cmd, message)
643     if must_switch:
644         __set_head(commit_id)
645
646     return commit_id
647
648 def apply_diff(rev1, rev2, check_index = True, files = None):
649     """Apply the diff between rev1 and rev2 onto the current
650     index. This function doesn't need to raise an exception since it
651     is only used for fast-pushing a patch. If this operation fails,
652     the pushing would fall back to the three-way merge.
653     """
654     if check_index:
655         index_opt = ['--index']
656     else:
657         index_opt = []
658
659     if not files:
660         files = []
661
662     diff_str = diff(files, rev1, rev2)
663     if diff_str:
664         try:
665             _input_str(['git-apply'] + index_opt, diff_str)
666         except GitException:
667             return False
668
669     return True
670
671 def merge(base, head1, head2, recursive = False):
672     """Perform a 3-way merge between base, head1 and head2 into the
673     local tree
674     """
675     refresh_index()
676
677     err_output = None
678     if recursive:
679         # this operation tracks renames but it is slower (used in
680         # general when pushing or picking patches)
681         try:
682             # use _output() to mask the verbose prints of the tool
683             _output(['git-merge-recursive', base, '--', head1, head2])
684         except GitException, ex:
685             err_output = str(ex)
686             pass
687     else:
688         # the fast case where we don't track renames (used when the
689         # distance between base and heads is small, i.e. folding or
690         # synchronising patches)
691         if __run('git-read-tree -u -m --aggressive',
692                  [base, head1, head2]) != 0:
693             raise GitException, 'git-read-tree failed (local changes maybe?)'
694
695     # check the index for unmerged entries
696     files = {}
697     stages_re = re.compile('^([0-7]+) ([0-9a-f]{40}) ([1-3])\t(.*)$', re.S)
698
699     for line in _output(['git-ls-files', '--unmerged', '--stage', '-z']).split('\0'):
700         if not line:
701             continue
702
703         mode, hash, stage, path = stages_re.findall(line)[0]
704
705         if not path in files:
706             files[path] = {}
707             files[path]['1'] = ('', '')
708             files[path]['2'] = ('', '')
709             files[path]['3'] = ('', '')
710
711         files[path][stage] = (mode, hash)
712
713     if err_output and not files:
714         # if no unmerged files, there was probably a different type of
715         # error and we have to abort the merge
716         raise GitException, err_output
717
718     # merge the unmerged files
719     errors = False
720     for path in files:
721         # remove additional files that might be generated for some
722         # newer versions of GIT
723         for suffix in [base, head1, head2]:
724             if not suffix:
725                 continue
726             fname = path + '~' + suffix
727             if os.path.exists(fname):
728                 os.remove(fname)
729
730         stages = files[path]
731         if gitmergeonefile.merge(stages['1'][1], stages['2'][1],
732                                  stages['3'][1], path, stages['1'][0],
733                                  stages['2'][0], stages['3'][0]) != 0:
734             errors = True
735
736     if errors:
737         raise GitException, 'GIT index merging failed (possible conflicts)'
738
739 def status(files = None, modified = False, new = False, deleted = False,
740            conflict = False, unknown = False, noexclude = False):
741     """Show the tree status
742     """
743     if not files:
744         files = []
745
746     cache_files = __tree_status(files, unknown = True, noexclude = noexclude)
747     all = not (modified or new or deleted or conflict or unknown)
748
749     if not all:
750         filestat = []
751         if modified:
752             filestat.append('M')
753         if new:
754             filestat.append('A')
755             filestat.append('N')
756         if deleted:
757             filestat.append('D')
758         if conflict:
759             filestat.append('C')
760         if unknown:
761             filestat.append('?')
762         cache_files = [x for x in cache_files if x[0] in filestat]
763
764     for fs in cache_files:
765         if files and not fs[1] in files:
766             continue
767         if all:
768             out.stdout('%s %s' % (fs[0], fs[1]))
769         else:
770             out.stdout('%s' % fs[1])
771
772 def diff(files = None, rev1 = 'HEAD', rev2 = None, out_fd = None,
773          diff_flags = []):
774     """Show the diff between rev1 and rev2
775     """
776     if not files:
777         files = []
778
779     if rev1 and rev2:
780         diff_str = _output(['git-diff-tree', '-p'] + diff_flags
781                            + [rev1, rev2, '--'] + files)
782     elif rev1 or rev2:
783         refresh_index()
784         if rev2:
785             diff_str = _output(['git-diff-index', '-p', '-R']
786                                + diff_flags + [rev2, '--'] + files)
787         else:
788             diff_str = _output(['git-diff-index', '-p']
789                                + diff_flags + [rev1, '--'] + files)
790     else:
791         diff_str = ''
792
793     if out_fd:
794         out_fd.write(diff_str)
795     else:
796         return diff_str
797
798 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
799     """Return the diffstat between rev1 and rev2
800     """
801     if not files:
802         files = []
803
804     p=popen2.Popen3('git-apply --stat')
805     diff(files, rev1, rev2, p.tochild)
806     p.tochild.close()
807     diff_str = p.fromchild.read().rstrip()
808     if p.wait():
809         raise GitException, 'git.diffstat failed'
810     return diff_str
811
812 def files(rev1, rev2):
813     """Return the files modified between rev1 and rev2
814     """
815
816     result = ''
817     for line in _output_lines(['git-diff-tree', '-r', rev1, rev2]):
818         result += '%s %s\n' % tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
819
820     return result.rstrip()
821
822 def barefiles(rev1, rev2):
823     """Return the files modified between rev1 and rev2, without status info
824     """
825
826     result = ''
827     for line in _output_lines(['git-diff-tree', '-r', rev1, rev2]):
828         result += '%s\n' % line.rstrip().split(' ',4)[-1].split('\t',1)[-1]
829
830     return result.rstrip()
831
832 def pretty_commit(commit_id = 'HEAD'):
833     """Return a given commit (log + diff)
834     """
835     return _output(['git-diff-tree', '--cc', '--always', '--pretty', '-r',
836                     commit_id])
837
838 def checkout(files = None, tree_id = None, force = False):
839     """Check out the given or all files
840     """
841     if not files:
842         files = []
843
844     if tree_id and __run('git-read-tree --reset', [tree_id]) != 0:
845         raise GitException, 'Failed git-read-tree --reset %s' % tree_id
846
847     checkout_cmd = 'git-checkout-index -q -u'
848     if force:
849         checkout_cmd += ' -f'
850     if len(files) == 0:
851         checkout_cmd += ' -a'
852     else:
853         checkout_cmd += ' --'
854
855     if __run(checkout_cmd, files) != 0:
856         raise GitException, 'Failed git-checkout-index'
857
858 def switch(tree_id, keep = False):
859     """Switch the tree to the given id
860     """
861     if not keep:
862         refresh_index()
863         if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
864             raise GitException, 'git-read-tree failed (local changes maybe?)'
865
866     __set_head(tree_id)
867
868 def reset(files = None, tree_id = None, check_out = True):
869     """Revert the tree changes relative to the given tree_id. It removes
870     any local changes
871     """
872     if not tree_id:
873         tree_id = get_head()
874
875     if check_out:
876         cache_files = __tree_status(files, tree_id)
877         # files which were added but need to be removed
878         rm_files =  [x[1] for x in cache_files if x[0] in ['A']]
879
880         checkout(files, tree_id, True)
881         # checkout doesn't remove files
882         map(os.remove, rm_files)
883
884     # if the reset refers to the whole tree, switch the HEAD as well
885     if not files:
886         __set_head(tree_id)
887
888 def fetch(repository = 'origin', refspec = None):
889     """Fetches changes from the remote repository, using 'git-fetch'
890     by default.
891     """
892     # we update the HEAD
893     __clear_head_cache()
894
895     args = [repository]
896     if refspec:
897         args.append(refspec)
898
899     command = config.get('branch.%s.stgit.fetchcmd' % get_head_file()) or \
900               config.get('stgit.fetchcmd')
901     if __run(command, args) != 0:
902         raise GitException, 'Failed "%s %s"' % (command, repository)
903
904 def pull(repository = 'origin', refspec = None):
905     """Fetches changes from the remote repository, using 'git-pull'
906     by default.
907     """
908     # we update the HEAD
909     __clear_head_cache()
910
911     args = [repository]
912     if refspec:
913         args.append(refspec)
914
915     command = config.get('branch.%s.stgit.pullcmd' % get_head_file()) or \
916               config.get('stgit.pullcmd')
917     if __run(command, args) != 0:
918         raise GitException, 'Failed "%s %s"' % (command, repository)
919
920 def repack():
921     """Repack all objects into a single pack
922     """
923     __run('git-repack -a -d -f')
924
925 def apply_patch(filename = None, diff = None, base = None,
926                 fail_dump = True):
927     """Apply a patch onto the current or given index. There must not
928     be any local changes in the tree, otherwise the command fails
929     """
930     if diff is None:
931         if filename:
932             f = file(filename)
933         else:
934             f = sys.stdin
935         diff = f.read()
936         if filename:
937             f.close()
938
939     if base:
940         orig_head = get_head()
941         switch(base)
942     else:
943         refresh_index()
944
945     try:
946         _input_str(['git-apply', '--index'], diff)
947     except GitException:
948         if base:
949             switch(orig_head)
950         if fail_dump:
951             # write the failed diff to a file
952             f = file('.stgit-failed.patch', 'w+')
953             f.write(diff)
954             f.close()
955             out.warn('Diff written to the .stgit-failed.patch file')
956
957         raise
958
959     if base:
960         top = commit(message = 'temporary commit used for applying a patch',
961                      parents = [base])
962         switch(orig_head)
963         merge(base, orig_head, top)
964
965 def clone(repository, local_dir):
966     """Clone a remote repository. At the moment, just use the
967     'git-clone' script
968     """
969     if __run('git-clone', [repository, local_dir]) != 0:
970         raise GitException, 'Failed "git-clone %s %s"' \
971               % (repository, local_dir)
972
973 def modifying_revs(files, base_rev, head_rev):
974     """Return the revisions from the list modifying the given files
975     """
976     cmd = ['git-rev-list', '%s..%s' % (base_rev, head_rev), '--']
977     revs = [line.strip() for line in _output_lines(cmd + files)]
978
979     return revs
980
981
982 def refspec_localpart(refspec):
983     m = re.match('^[^:]*:([^:]*)$', refspec)
984     if m:
985         return m.group(1)
986     else:
987         raise GitException, 'Cannot parse refspec "%s"' % line
988
989 def refspec_remotepart(refspec):
990     m = re.match('^([^:]*):[^:]*$', refspec)
991     if m:
992         return m.group(1)
993     else:
994         raise GitException, 'Cannot parse refspec "%s"' % line
995     
996
997 def __remotes_from_config():
998     return config.sections_matching(r'remote\.(.*)\.url')
999
1000 def __remotes_from_dir(dir):
1001     d = os.path.join(basedir.get(), dir)
1002     if os.path.exists(d):
1003         return os.listdir(d)
1004     else:
1005         return None
1006
1007 def remotes_list():
1008     """Return the list of remotes in the repository
1009     """
1010
1011     return Set(__remotes_from_config()) | \
1012            Set(__remotes_from_dir('remotes')) | \
1013            Set(__remotes_from_dir('branches'))
1014
1015 def remotes_local_branches(remote):
1016     """Returns the list of local branches fetched from given remote
1017     """
1018
1019     branches = []
1020     if remote in __remotes_from_config():
1021         for line in config.getall('remote.%s.fetch' % remote):
1022             branches.append(refspec_localpart(line))
1023     elif remote in __remotes_from_dir('remotes'):
1024         stream = open(os.path.join(basedir.get(), 'remotes', remote), 'r')
1025         for line in stream:
1026             # Only consider Pull lines
1027             m = re.match('^Pull: (.*)\n$', line)
1028             if m:
1029                 branches.append(refspec_localpart(m.group(1)))
1030         stream.close()
1031     elif remote in __remotes_from_dir('branches'):
1032         # old-style branches only declare one branch
1033         branches.append('refs/heads/'+remote);
1034     else:
1035         raise GitException, 'Unknown remote "%s"' % remote
1036
1037     return branches
1038
1039 def identify_remote(branchname):
1040     """Return the name for the remote to pull the given branchname
1041     from, or None if we believe it is a local branch.
1042     """
1043
1044     for remote in remotes_list():
1045         if branchname in remotes_local_branches(remote):
1046             return remote
1047
1048     # if we get here we've found nothing, the branch is a local one
1049     return None
1050
1051 def fetch_head():
1052     """Return the git id for the tip of the parent branch as left by
1053     'git fetch'.
1054     """
1055
1056     fetch_head=None
1057     stream = open(os.path.join(basedir.get(), 'FETCH_HEAD'), "r")
1058     for line in stream:
1059         # Only consider lines not tagged not-for-merge
1060         m = re.match('^([^\t]*)\t\t', line)
1061         if m:
1062             if fetch_head:
1063                 raise GitException, "StGit does not support multiple FETCH_HEAD"
1064             else:
1065                 fetch_head=m.group(1)
1066     stream.close()
1067
1068     # here we are sure to have a single fetch_head
1069     return fetch_head