chiark / gitweb /
Handle refresh of changed files with non-ASCII names
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re, gitmergeonefile
22 from shutil import copyfile
23
24 from stgit.exception import *
25 from stgit import basedir
26 from stgit.utils import *
27 from stgit.out import *
28 from stgit.run import *
29 from stgit.config import config
30
31 # git exception class
32 class GitException(StgException):
33     pass
34
35 # When a subprocess has a problem, we want the exception to be a
36 # subclass of GitException.
37 class GitRunException(GitException):
38     pass
39 class GRun(Run):
40     exc = GitRunException
41     def __init__(self, *cmd):
42         """Initialise the Run object and insert the 'git' command name.
43         """
44         Run.__init__(self, 'git', *cmd)
45
46
47 #
48 # Classes
49 #
50
51 class Person:
52     """An author, committer, etc."""
53     def __init__(self, name = None, email = None, date = '',
54                  desc = None):
55         self.name = self.email = self.date = None
56         if name or email or date:
57             assert not desc
58             self.name = name
59             self.email = email
60             self.date = date
61         elif desc:
62             assert not (name or email or date)
63             def parse_desc(s):
64                 m = re.match(r'^(.+)<(.+)>(.*)$', s)
65                 assert m
66                 return [x.strip() or None for x in m.groups()]
67             self.name, self.email, self.date = parse_desc(desc)
68     def set_name(self, val):
69         if val:
70             self.name = val
71     def set_email(self, val):
72         if val:
73             self.email = val
74     def set_date(self, val):
75         if val:
76             self.date = val
77     def __str__(self):
78         if self.name and self.email:
79             return '%s <%s>' % (self.name, self.email)
80         else:
81             raise GitException, 'not enough identity data'
82
83 class Commit:
84     """Handle the commit objects
85     """
86     def __init__(self, id_hash):
87         self.__id_hash = id_hash
88
89         lines = GRun('cat-file', 'commit', id_hash).output_lines()
90         for i in range(len(lines)):
91             line = lines[i]
92             if not line:
93                 break # we've seen all the header fields
94             key, val = line.split(' ', 1)
95             if key == 'tree':
96                 self.__tree = val
97             elif key == 'author':
98                 self.__author = val
99             elif key == 'committer':
100                 self.__committer = val
101             else:
102                 pass # ignore other headers
103         self.__log = '\n'.join(lines[i+1:])
104
105     def get_id_hash(self):
106         return self.__id_hash
107
108     def get_tree(self):
109         return self.__tree
110
111     def get_parent(self):
112         parents = self.get_parents()
113         if parents:
114             return parents[0]
115         else:
116             return None
117
118     def get_parents(self):
119         return GRun('rev-list', '--parents', '--max-count=1', self.__id_hash
120                     ).output_one_line().split()[1:]
121
122     def get_author(self):
123         return self.__author
124
125     def get_committer(self):
126         return self.__committer
127
128     def get_log(self):
129         return self.__log
130
131     def __str__(self):
132         return self.get_id_hash()
133
134 # dictionary of Commit objects, used to avoid multiple calls to git
135 __commits = dict()
136
137 #
138 # Functions
139 #
140
141 def get_commit(id_hash):
142     """Commit objects factory. Save/look-up them in the __commits
143     dictionary
144     """
145     global __commits
146
147     if id_hash in __commits:
148         return __commits[id_hash]
149     else:
150         commit = Commit(id_hash)
151         __commits[id_hash] = commit
152         return commit
153
154 def get_conflicts():
155     """Return the list of file conflicts
156     """
157     conflicts_file = os.path.join(basedir.get(), 'conflicts')
158     if os.path.isfile(conflicts_file):
159         f = file(conflicts_file)
160         names = [line.strip() for line in f.readlines()]
161         f.close()
162         return names
163     else:
164         return None
165
166 def exclude_files():
167     files = [os.path.join(basedir.get(), 'info', 'exclude')]
168     user_exclude = config.get('core.excludesfile')
169     if user_exclude:
170         files.append(user_exclude)
171     return files
172
173 def ls_files(files, tree = 'HEAD', full_name = True):
174     """Return the files known to GIT or raise an error otherwise. It also
175     converts the file to the full path relative the the .git directory.
176     """
177     if not files:
178         return []
179
180     args = []
181     if tree:
182         args.append('--with-tree=%s' % tree)
183     if full_name:
184         args.append('--full-name')
185     args.append('--')
186     args.extend(files)
187     try:
188         return GRun('ls-files', '--error-unmatch', *args).output_lines()
189     except GitRunException:
190         # just hide the details of the 'git ls-files' command we use
191         raise GitException, \
192             'Some of the given paths are either missing or not known to GIT'
193
194 def parse_git_ls(output):
195     t = None
196     for line in output.split('\0'):
197         if not line:
198             # There's a zero byte at the end of the output, which
199             # gives us an empty string as the last "line".
200             continue
201         if t == None:
202             mode_a, mode_b, sha1_a, sha1_b, t = line.split(' ')
203         else:
204             yield (t, line)
205             t = None
206
207 def tree_status(files = None, tree_id = 'HEAD', unknown = False,
208                   noexclude = True, verbose = False, diff_flags = []):
209     """Get the status of all changed files, or of a selected set of
210     files. Returns a list of pairs - (status, filename).
211
212     If 'not files', it will check all files, and optionally all
213     unknown files.  If 'files' is a list, it will only check the files
214     in the list.
215     """
216     assert not files or not unknown
217
218     if verbose:
219         out.start('Checking for changes in the working directory')
220
221     refresh_index()
222
223     if files is None:
224         files = []
225     cache_files = []
226
227     # unknown files
228     if unknown:
229         cmd = ['ls-files', '-z', '--others', '--directory',
230                '--no-empty-directory']
231         if not noexclude:
232             cmd += ['--exclude=%s' % s for s in
233                     ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
234             cmd += ['--exclude-per-directory=.gitignore']
235             cmd += ['--exclude-from=%s' % fn
236                     for fn in exclude_files()
237                     if os.path.exists(fn)]
238
239         lines = GRun(*cmd).raw_output().split('\0')
240         cache_files += [('?', line) for line in lines if line]
241
242     # conflicted files
243     conflicts = get_conflicts()
244     if not conflicts:
245         conflicts = []
246     cache_files += [('C', filename) for filename in conflicts
247                     if not files or filename in files]
248     reported_files = set(conflicts)
249     files_left = [f for f in files if f not in reported_files]
250
251     # files in the index. Only execute this code if no files were
252     # specified when calling the function (i.e. report all files) or
253     # files were specified but already found in the previous step
254     if not files or files_left:
255         args = diff_flags + [tree_id]
256         if files_left:
257             args += ['--'] + files_left
258         for t, fn in parse_git_ls(GRun('diff-index', '-z', *args).raw_output()):
259             # the condition is needed in case files is emtpy and
260             # diff-index lists those already reported
261             if not fn in reported_files:
262                 cache_files.append((t, fn))
263                 reported_files.add(fn)
264         files_left = [f for f in files if f not in reported_files]
265
266     # files in the index but changed on (or removed from) disk. Only
267     # execute this code if no files were specified when calling the
268     # function (i.e. report all files) or files were specified but
269     # already found in the previous step
270     if not files or files_left:
271         args = list(diff_flags)
272         if files_left:
273             args += ['--'] + files_left
274         for t, fn in parse_git_ls(GRun('diff-files', '-z', *args).raw_output()):
275             # the condition is needed in case files is empty and
276             # diff-files lists those already reported
277             if not fn in reported_files:
278                 cache_files.append((t, fn))
279                 reported_files.add(fn)
280
281     if verbose:
282         out.done()
283
284     return cache_files
285
286 def local_changes(verbose = True):
287     """Return true if there are local changes in the tree
288     """
289     return len(tree_status(verbose = verbose)) != 0
290
291 def get_heads():
292     heads = []
293     hr = re.compile(r'^[0-9a-f]{40} refs/heads/(.+)$')
294     for line in GRun('show-ref', '--heads').output_lines():
295         m = hr.match(line)
296         heads.append(m.group(1))
297     return heads
298
299 # HEAD value cached
300 __head = None
301
302 def get_head():
303     """Verifies the HEAD and returns the SHA1 id that represents it
304     """
305     global __head
306
307     if not __head:
308         __head = rev_parse('HEAD')
309     return __head
310
311 class DetachedHeadException(GitException):
312     def __init__(self):
313         GitException.__init__(self, 'Not on any branch')
314
315 def get_head_file():
316     """Return the name of the file pointed to by the HEAD symref.
317     Throw an exception if HEAD is detached."""
318     try:
319         return strip_prefix(
320             'refs/heads/', GRun('symbolic-ref', '-q', 'HEAD'
321                                 ).output_one_line())
322     except GitRunException:
323         raise DetachedHeadException()
324
325 def set_head_file(ref):
326     """Resets HEAD to point to a new ref
327     """
328     # head cache flushing is needed since we might have a different value
329     # in the new head
330     __clear_head_cache()
331     try:
332         GRun('symbolic-ref', 'HEAD', 'refs/heads/%s' % ref).run()
333     except GitRunException:
334         raise GitException, 'Could not set head to "%s"' % ref
335
336 def set_ref(ref, val):
337     """Point ref at a new commit object."""
338     try:
339         GRun('update-ref', ref, val).run()
340     except GitRunException:
341         raise GitException, 'Could not update %s to "%s".' % (ref, val)
342
343 def set_branch(branch, val):
344     set_ref('refs/heads/%s' % branch, val)
345
346 def __set_head(val):
347     """Sets the HEAD value
348     """
349     global __head
350
351     if not __head or __head != val:
352         set_ref('HEAD', val)
353         __head = val
354
355     # only allow SHA1 hashes
356     assert(len(__head) == 40)
357
358 def __clear_head_cache():
359     """Sets the __head to None so that a re-read is forced
360     """
361     global __head
362
363     __head = None
364
365 def refresh_index():
366     """Refresh index with stat() information from the working directory.
367     """
368     GRun('update-index', '-q', '--unmerged', '--refresh').run()
369
370 def rev_parse(git_id):
371     """Parse the string and return a verified SHA1 id
372     """
373     try:
374         return GRun('rev-parse', '--verify', git_id
375                     ).discard_stderr().output_one_line()
376     except GitRunException:
377         raise GitException, 'Unknown revision: %s' % git_id
378
379 def ref_exists(ref):
380     try:
381         rev_parse(ref)
382         return True
383     except GitException:
384         return False
385
386 def branch_exists(branch):
387     return ref_exists('refs/heads/%s' % branch)
388
389 def create_branch(new_branch, tree_id = None):
390     """Create a new branch in the git repository
391     """
392     if branch_exists(new_branch):
393         raise GitException, 'Branch "%s" already exists' % new_branch
394
395     current_head_file = get_head_file()
396     current_head = get_head()
397     set_head_file(new_branch)
398     __set_head(current_head)
399
400     # a checkout isn't needed if new branch points to the current head
401     if tree_id:
402         try:
403             switch(tree_id)
404         except GitException:
405             # Tree switching failed. Revert the head file
406             set_head_file(current_head_file)
407             delete_branch(new_branch)
408             raise
409
410     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
411         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
412
413 def switch_branch(new_branch):
414     """Switch to a git branch
415     """
416     global __head
417
418     if not branch_exists(new_branch):
419         raise GitException, 'Branch "%s" does not exist' % new_branch
420
421     tree_id = rev_parse('refs/heads/%s^{commit}' % new_branch)
422     if tree_id != get_head():
423         refresh_index()
424         try:
425             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
426         except GitRunException:
427             raise GitException, 'read-tree failed (local changes maybe?)'
428         __head = tree_id
429     set_head_file(new_branch)
430
431     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
432         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
433
434 def delete_ref(ref):
435     if not ref_exists(ref):
436         raise GitException, '%s does not exist' % ref
437     sha1 = GRun('show-ref', '-s', ref).output_one_line()
438     try:
439         GRun('update-ref', '-d', ref, sha1).run()
440     except GitRunException:
441         raise GitException, 'Failed to delete ref %s' % ref
442
443 def delete_branch(name):
444     delete_ref('refs/heads/%s' % name)
445
446 def rename_ref(from_ref, to_ref):
447     if not ref_exists(from_ref):
448         raise GitException, '"%s" does not exist' % from_ref
449     if ref_exists(to_ref):
450         raise GitException, '"%s" already exists' % to_ref
451
452     sha1 = GRun('show-ref', '-s', from_ref).output_one_line()
453     try:
454         GRun('update-ref', to_ref, sha1, '0'*40).run()
455     except GitRunException:
456         raise GitException, 'Failed to create new ref %s' % to_ref
457     try:
458         GRun('update-ref', '-d', from_ref, sha1).run()
459     except GitRunException:
460         raise GitException, 'Failed to delete ref %s' % from_ref
461
462 def rename_branch(from_name, to_name):
463     """Rename a git branch."""
464     rename_ref('refs/heads/%s' % from_name, 'refs/heads/%s' % to_name)
465     try:
466         if get_head_file() == from_name:
467             set_head_file(to_name)
468     except DetachedHeadException:
469         pass # detached HEAD, so the renamee can't be the current branch
470     reflog_dir = os.path.join(basedir.get(), 'logs', 'refs', 'heads')
471     if os.path.exists(reflog_dir) \
472            and os.path.exists(os.path.join(reflog_dir, from_name)):
473         rename(reflog_dir, from_name, to_name)
474
475 def add(names):
476     """Add the files or recursively add the directory contents
477     """
478     # generate the file list
479     files = []
480     for i in names:
481         if not os.path.exists(i):
482             raise GitException, 'Unknown file or directory: %s' % i
483
484         if os.path.isdir(i):
485             # recursive search. We only add files
486             for root, dirs, local_files in os.walk(i):
487                 for name in [os.path.join(root, f) for f in local_files]:
488                     if os.path.isfile(name):
489                         files.append(os.path.normpath(name))
490         elif os.path.isfile(i):
491             files.append(os.path.normpath(i))
492         else:
493             raise GitException, '%s is not a file or directory' % i
494
495     if files:
496         try:
497             GRun('update-index', '--add', '--').xargs(files)
498         except GitRunException:
499             raise GitException, 'Unable to add file'
500
501 def __copy_single(source, target, target2=''):
502     """Copy file or dir named 'source' to name target+target2"""
503
504     # "source" (file or dir) must match one or more git-controlled file
505     realfiles = GRun('ls-files', source).output_lines()
506     if len(realfiles) == 0:
507         raise GitException, '"%s" matches no git-controled files' % source
508
509     if os.path.isdir(source):
510         # physically copy the files, and record them to add them in one run
511         newfiles = []
512         re_string='^'+source+'/(.*)$'
513         prefix_regexp = re.compile(re_string)
514         for f in [f.strip() for f in realfiles]:
515             m = prefix_regexp.match(f)
516             if not m:
517                 raise Exception, '"%s" does not match "%s"' % (f, re_string)
518             newname = target+target2+'/'+m.group(1)
519             if not os.path.exists(os.path.dirname(newname)):
520                 os.makedirs(os.path.dirname(newname))
521             copyfile(f, newname)
522             newfiles.append(newname)
523
524         add(newfiles)
525     else: # files, symlinks, ...
526         newname = target+target2
527         copyfile(source, newname)
528         add([newname])
529
530
531 def copy(filespecs, target):
532     if os.path.isdir(target):
533         # target is a directory: copy each entry on the command line,
534         # with the same name, into the target
535         target = target.rstrip('/')
536         
537         # first, check that none of the children of the target
538         # matching the command line aleady exist
539         for filespec in filespecs:
540             entry = target+ '/' + os.path.basename(filespec.rstrip('/'))
541             if os.path.exists(entry):
542                 raise GitException, 'Target "%s" already exists' % entry
543         
544         for filespec in filespecs:
545             filespec = filespec.rstrip('/')
546             basename = '/' + os.path.basename(filespec)
547             __copy_single(filespec, target, basename)
548
549     elif os.path.exists(target):
550         raise GitException, 'Target "%s" exists but is not a directory' % target
551     elif len(filespecs) != 1:
552         raise GitException, 'Cannot copy more than one file to non-directory'
553
554     else:
555         # at this point: len(filespecs)==1 and target does not exist
556
557         # check target directory
558         targetdir = os.path.dirname(target)
559         if targetdir != '' and not os.path.isdir(targetdir):
560             raise GitException, 'Target directory "%s" does not exist' % targetdir
561
562         __copy_single(filespecs[0].rstrip('/'), target)
563         
564
565 def rm(files, force = False):
566     """Remove a file from the repository
567     """
568     if not force:
569         for f in files:
570             if os.path.exists(f):
571                 raise GitException, '%s exists. Remove it first' %f
572         if files:
573             GRun('update-index', '--remove', '--').xargs(files)
574     else:
575         if files:
576             GRun('update-index', '--force-remove', '--').xargs(files)
577
578 # Persons caching
579 __user = None
580 __author = None
581 __committer = None
582
583 def user():
584     """Return the user information.
585     """
586     global __user
587     if not __user:
588         name=config.get('user.name')
589         email=config.get('user.email')
590         __user = Person(name, email)
591     return __user;
592
593 def author():
594     """Return the author information.
595     """
596     global __author
597     if not __author:
598         try:
599             # the environment variables take priority over config
600             try:
601                 date = os.environ['GIT_AUTHOR_DATE']
602             except KeyError:
603                 date = ''
604             __author = Person(os.environ['GIT_AUTHOR_NAME'],
605                               os.environ['GIT_AUTHOR_EMAIL'],
606                               date)
607         except KeyError:
608             __author = user()
609     return __author
610
611 def committer():
612     """Return the author information.
613     """
614     global __committer
615     if not __committer:
616         try:
617             # the environment variables take priority over config
618             try:
619                 date = os.environ['GIT_COMMITTER_DATE']
620             except KeyError:
621                 date = ''
622             __committer = Person(os.environ['GIT_COMMITTER_NAME'],
623                                  os.environ['GIT_COMMITTER_EMAIL'],
624                                  date)
625         except KeyError:
626             __committer = user()
627     return __committer
628
629 def update_cache(files = None, force = False):
630     """Update the cache information for the given files
631     """
632     cache_files = tree_status(files, verbose = False)
633
634     # everything is up-to-date
635     if len(cache_files) == 0:
636         return False
637
638     # check for unresolved conflicts
639     if not force and [x for x in cache_files
640                       if x[0] not in ['M', 'N', 'A', 'D']]:
641         raise GitException, 'Updating cache failed: unresolved conflicts'
642
643     # update the cache
644     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
645     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
646     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
647
648     GRun('update-index', '--add', '--').xargs(add_files)
649     GRun('update-index', '--force-remove', '--').xargs(rm_files)
650     GRun('update-index', '--').xargs(m_files)
651
652     return True
653
654 def commit(message, files = None, parents = None, allowempty = False,
655            cache_update = True, tree_id = None, set_head = False,
656            author_name = None, author_email = None, author_date = None,
657            committer_name = None, committer_email = None):
658     """Commit the current tree to repository
659     """
660     if not parents:
661         parents = []
662
663     # Get the tree status
664     if cache_update and parents != []:
665         changes = update_cache(files)
666         if not changes and not allowempty:
667             raise GitException, 'No changes to commit'
668
669     # get the commit message
670     if not message:
671         message = '\n'
672     elif message[-1:] != '\n':
673         message += '\n'
674
675     # write the index to repository
676     if tree_id == None:
677         tree_id = GRun('write-tree').output_one_line()
678         set_head = True
679
680     # the commit
681     env = {}
682     if author_name:
683         env['GIT_AUTHOR_NAME'] = author_name
684     if author_email:
685         env['GIT_AUTHOR_EMAIL'] = author_email
686     if author_date:
687         env['GIT_AUTHOR_DATE'] = author_date
688     if committer_name:
689         env['GIT_COMMITTER_NAME'] = committer_name
690     if committer_email:
691         env['GIT_COMMITTER_EMAIL'] = committer_email
692     commit_id = GRun('commit-tree', tree_id,
693                      *sum([['-p', p] for p in parents], [])
694                      ).env(env).raw_input(message).output_one_line()
695     if set_head:
696         __set_head(commit_id)
697
698     return commit_id
699
700 def apply_diff(rev1, rev2, check_index = True, files = None):
701     """Apply the diff between rev1 and rev2 onto the current
702     index. This function doesn't need to raise an exception since it
703     is only used for fast-pushing a patch. If this operation fails,
704     the pushing would fall back to the three-way merge.
705     """
706     if check_index:
707         index_opt = ['--index']
708     else:
709         index_opt = []
710
711     if not files:
712         files = []
713
714     diff_str = diff(files, rev1, rev2)
715     if diff_str:
716         try:
717             GRun('apply', *index_opt).raw_input(
718                 diff_str).discard_stderr().no_output()
719         except GitRunException:
720             return False
721
722     return True
723
724 def merge(base, head1, head2, recursive = False):
725     """Perform a 3-way merge between base, head1 and head2 into the
726     local tree
727     """
728     refresh_index()
729
730     err_output = None
731     if recursive:
732         # this operation tracks renames but it is slower (used in
733         # general when pushing or picking patches)
734         try:
735             # discard output to mask the verbose prints of the tool
736             GRun('merge-recursive', base, '--', head1, head2
737                  ).discard_output()
738         except GitRunException, ex:
739             err_output = str(ex)
740             pass
741     else:
742         # the fast case where we don't track renames (used when the
743         # distance between base and heads is small, i.e. folding or
744         # synchronising patches)
745         try:
746             GRun('read-tree', '-u', '-m', '--aggressive',
747                  base, head1, head2).run()
748         except GitRunException:
749             raise GitException, 'read-tree failed (local changes maybe?)'
750
751     # check the index for unmerged entries
752     files = {}
753     stages_re = re.compile('^([0-7]+) ([0-9a-f]{40}) ([1-3])\t(.*)$', re.S)
754
755     for line in GRun('ls-files', '--unmerged', '--stage', '-z'
756                      ).raw_output().split('\0'):
757         if not line:
758             continue
759
760         mode, hash, stage, path = stages_re.findall(line)[0]
761
762         if not path in files:
763             files[path] = {}
764             files[path]['1'] = ('', '')
765             files[path]['2'] = ('', '')
766             files[path]['3'] = ('', '')
767
768         files[path][stage] = (mode, hash)
769
770     if err_output and not files:
771         # if no unmerged files, there was probably a different type of
772         # error and we have to abort the merge
773         raise GitException, err_output
774
775     # merge the unmerged files
776     errors = False
777     for path in files:
778         # remove additional files that might be generated for some
779         # newer versions of GIT
780         for suffix in [base, head1, head2]:
781             if not suffix:
782                 continue
783             fname = path + '~' + suffix
784             if os.path.exists(fname):
785                 os.remove(fname)
786
787         stages = files[path]
788         if gitmergeonefile.merge(stages['1'][1], stages['2'][1],
789                                  stages['3'][1], path, stages['1'][0],
790                                  stages['2'][0], stages['3'][0]) != 0:
791             errors = True
792
793     if errors:
794         raise GitException, 'GIT index merging failed (possible conflicts)'
795
796 def diff(files = None, rev1 = 'HEAD', rev2 = None, diff_flags = [],
797          binary = True):
798     """Show the diff between rev1 and rev2
799     """
800     if not files:
801         files = []
802     if binary and '--binary' not in diff_flags:
803         diff_flags = diff_flags + ['--binary']
804
805     if rev1 and rev2:
806         return GRun('diff-tree', '-p',
807                     *(diff_flags + [rev1, rev2, '--'] + files)).raw_output()
808     elif rev1 or rev2:
809         refresh_index()
810         if rev2:
811             return GRun('diff-index', '-p', '-R',
812                         *(diff_flags + [rev2, '--'] + files)).raw_output()
813         else:
814             return GRun('diff-index', '-p',
815                         *(diff_flags + [rev1, '--'] + files)).raw_output()
816     else:
817         return ''
818
819 # TODO: take another parameter representing a diff string as we
820 # usually invoke git.diff() form the calling functions
821 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
822     """Return the diffstat between rev1 and rev2."""
823     return GRun('apply', '--stat', '--summary'
824                 ).raw_input(diff(files, rev1, rev2)).raw_output()
825
826 def files(rev1, rev2, diff_flags = []):
827     """Return the files modified between rev1 and rev2
828     """
829
830     result = []
831     for line in GRun('diff-tree', *(diff_flags + ['-r', rev1, rev2])
832                      ).output_lines():
833         result.append('%s %s' % tuple(line.split(' ', 4)[-1].split('\t', 1)))
834
835     return '\n'.join(result)
836
837 def barefiles(rev1, rev2):
838     """Return the files modified between rev1 and rev2, without status info
839     """
840
841     result = []
842     for line in GRun('diff-tree', '-r', rev1, rev2).output_lines():
843         result.append(line.split(' ', 4)[-1].split('\t', 1)[-1])
844
845     return '\n'.join(result)
846
847 def pretty_commit(commit_id = 'HEAD', flags = []):
848     """Return a given commit (log + diff)
849     """
850     return GRun('show', *(flags + [commit_id])).raw_output()
851
852 def checkout(files = None, tree_id = None, force = False):
853     """Check out the given or all files
854     """
855     if tree_id:
856         try:
857             GRun('read-tree', '--reset', tree_id).run()
858         except GitRunException:
859             raise GitException, 'Failed "git read-tree" --reset %s' % tree_id
860
861     cmd = ['checkout-index', '-q', '-u']
862     if force:
863         cmd.append('-f')
864     if files:
865         GRun(*(cmd + ['--'])).xargs(files)
866     else:
867         GRun(*(cmd + ['-a'])).run()
868
869 def switch(tree_id, keep = False):
870     """Switch the tree to the given id
871     """
872     if keep:
873         # only update the index while keeping the local changes
874         GRun('read-tree', tree_id).run()
875     else:
876         refresh_index()
877         try:
878             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
879         except GitRunException:
880             raise GitException, 'read-tree failed (local changes maybe?)'
881
882     __set_head(tree_id)
883
884 def reset(files = None, tree_id = None, check_out = True):
885     """Revert the tree changes relative to the given tree_id. It removes
886     any local changes
887     """
888     if not tree_id:
889         tree_id = get_head()
890
891     if check_out:
892         cache_files = tree_status(files, tree_id)
893         # files which were added but need to be removed
894         rm_files =  [x[1] for x in cache_files if x[0] in ['A']]
895
896         checkout(files, tree_id, True)
897         # checkout doesn't remove files
898         map(os.remove, rm_files)
899
900     # if the reset refers to the whole tree, switch the HEAD as well
901     if not files:
902         __set_head(tree_id)
903
904 def fetch(repository = 'origin', refspec = None):
905     """Fetches changes from the remote repository, using 'git fetch'
906     by default.
907     """
908     # we update the HEAD
909     __clear_head_cache()
910
911     args = [repository]
912     if refspec:
913         args.append(refspec)
914
915     command = config.get('branch.%s.stgit.fetchcmd' % get_head_file()) or \
916               config.get('stgit.fetchcmd')
917     Run(*(command.split() + args)).run()
918
919 def pull(repository = 'origin', refspec = None):
920     """Fetches changes from the remote repository, using 'git pull'
921     by default.
922     """
923     # we update the HEAD
924     __clear_head_cache()
925
926     args = [repository]
927     if refspec:
928         args.append(refspec)
929
930     command = config.get('branch.%s.stgit.pullcmd' % get_head_file()) or \
931               config.get('stgit.pullcmd')
932     Run(*(command.split() + args)).run()
933
934 def rebase(tree_id = None):
935     """Rebase the current tree to the give tree_id. The tree_id
936     argument may be something other than a GIT id if an external
937     command is invoked.
938     """
939     command = config.get('branch.%s.stgit.rebasecmd' % get_head_file()) \
940                 or config.get('stgit.rebasecmd')
941     if tree_id:
942         args = [tree_id]
943     elif command:
944         args = []
945     else:
946         raise GitException, 'Default rebasing requires a commit id'
947     if command:
948         # clear the HEAD cache as the custom rebase command will update it
949         __clear_head_cache()
950         Run(*(command.split() + args)).run()
951     else:
952         # default rebasing
953         reset(tree_id = tree_id)
954
955 def repack():
956     """Repack all objects into a single pack
957     """
958     GRun('repack', '-a', '-d', '-f').run()
959
960 def apply_patch(filename = None, diff = None, base = None,
961                 fail_dump = True):
962     """Apply a patch onto the current or given index. There must not
963     be any local changes in the tree, otherwise the command fails
964     """
965     if diff is None:
966         if filename:
967             f = file(filename)
968         else:
969             f = sys.stdin
970         diff = f.read()
971         if filename:
972             f.close()
973
974     if base:
975         orig_head = get_head()
976         switch(base)
977     else:
978         refresh_index()
979
980     try:
981         GRun('apply', '--index').raw_input(diff).no_output()
982     except GitRunException:
983         if base:
984             switch(orig_head)
985         if fail_dump:
986             # write the failed diff to a file
987             f = file('.stgit-failed.patch', 'w+')
988             f.write(diff)
989             f.close()
990             out.warn('Diff written to the .stgit-failed.patch file')
991
992         raise
993
994     if base:
995         top = commit(message = 'temporary commit used for applying a patch',
996                      parents = [base])
997         switch(orig_head)
998         merge(base, orig_head, top)
999
1000 def clone(repository, local_dir):
1001     """Clone a remote repository. At the moment, just use the
1002     'git clone' script
1003     """
1004     GRun('clone', repository, local_dir).run()
1005
1006 def modifying_revs(files, base_rev, head_rev):
1007     """Return the revisions from the list modifying the given files."""
1008     return GRun('rev-list', '%s..%s' % (base_rev, head_rev), '--', *files
1009                 ).output_lines()
1010
1011 def refspec_localpart(refspec):
1012     m = re.match('^[^:]*:([^:]*)$', refspec)
1013     if m:
1014         return m.group(1)
1015     else:
1016         raise GitException, 'Cannot parse refspec "%s"' % line
1017
1018 def refspec_remotepart(refspec):
1019     m = re.match('^([^:]*):[^:]*$', refspec)
1020     if m:
1021         return m.group(1)
1022     else:
1023         raise GitException, 'Cannot parse refspec "%s"' % line
1024
1025 def __remotes_from_config():
1026     return config.sections_matching(r'remote\.(.*)\.url')
1027
1028 def __remotes_from_dir(dir):
1029     d = os.path.join(basedir.get(), dir)
1030     if os.path.exists(d):
1031         return os.listdir(d)
1032     else:
1033         return []
1034
1035 def remotes_list():
1036     """Return the list of remotes in the repository
1037     """
1038     return (set(__remotes_from_config())
1039             | set(__remotes_from_dir('remotes'))
1040             | set(__remotes_from_dir('branches')))
1041
1042 def remotes_local_branches(remote):
1043     """Returns the list of local branches fetched from given remote
1044     """
1045
1046     branches = []
1047     if remote in __remotes_from_config():
1048         for line in config.getall('remote.%s.fetch' % remote):
1049             branches.append(refspec_localpart(line))
1050     elif remote in __remotes_from_dir('remotes'):
1051         stream = open(os.path.join(basedir.get(), 'remotes', remote), 'r')
1052         for line in stream:
1053             # Only consider Pull lines
1054             m = re.match('^Pull: (.*)\n$', line)
1055             if m:
1056                 branches.append(refspec_localpart(m.group(1)))
1057         stream.close()
1058     elif remote in __remotes_from_dir('branches'):
1059         # old-style branches only declare one branch
1060         branches.append('refs/heads/'+remote);
1061     else:
1062         raise GitException, 'Unknown remote "%s"' % remote
1063
1064     return branches
1065
1066 def identify_remote(branchname):
1067     """Return the name for the remote to pull the given branchname
1068     from, or None if we believe it is a local branch.
1069     """
1070
1071     for remote in remotes_list():
1072         if branchname in remotes_local_branches(remote):
1073             return remote
1074
1075     # if we get here we've found nothing, the branch is a local one
1076     return None
1077
1078 def fetch_head():
1079     """Return the git id for the tip of the parent branch as left by
1080     'git fetch'.
1081     """
1082
1083     fetch_head=None
1084     stream = open(os.path.join(basedir.get(), 'FETCH_HEAD'), "r")
1085     for line in stream:
1086         # Only consider lines not tagged not-for-merge
1087         m = re.match('^([^\t]*)\t\t', line)
1088         if m:
1089             if fetch_head:
1090                 raise GitException, 'StGit does not support multiple FETCH_HEAD'
1091             else:
1092                 fetch_head=m.group(1)
1093     stream.close()
1094
1095     if not fetch_head:
1096         out.warn('No for-merge remote head found in FETCH_HEAD')
1097
1098     # here we are sure to have a single fetch_head
1099     return fetch_head
1100
1101 def all_refs():
1102     """Return a list of all refs in the current repository.
1103     """
1104
1105     return [line.split()[1] for line in GRun('show-ref').output_lines()]