chiark / gitweb /
8c637d5530a6ab285a2fb1135200d24f225319b9
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re, gitmergeonefile
22 from shutil import copyfile
23
24 from stgit.exception import *
25 from stgit import basedir
26 from stgit.utils import *
27 from stgit.out import *
28 from stgit.run import *
29 from stgit.config import config
30
31 # git exception class
32 class GitException(StgException):
33     pass
34
35 # When a subprocess has a problem, we want the exception to be a
36 # subclass of GitException.
37 class GitRunException(GitException):
38     pass
39 class GRun(Run):
40     exc = GitRunException
41     def __init__(self, *cmd):
42         """Initialise the Run object and insert the 'git' command name.
43         """
44         Run.__init__(self, 'git', *cmd)
45
46
47 #
48 # Classes
49 #
50
51 class Person:
52     """An author, committer, etc."""
53     def __init__(self, name = None, email = None, date = '',
54                  desc = None):
55         self.name = self.email = self.date = None
56         if name or email or date:
57             assert not desc
58             self.name = name
59             self.email = email
60             self.date = date
61         elif desc:
62             assert not (name or email or date)
63             def parse_desc(s):
64                 m = re.match(r'^(.+)<(.+)>(.*)$', s)
65                 assert m
66                 return [x.strip() or None for x in m.groups()]
67             self.name, self.email, self.date = parse_desc(desc)
68     def set_name(self, val):
69         if val:
70             self.name = val
71     def set_email(self, val):
72         if val:
73             self.email = val
74     def set_date(self, val):
75         if val:
76             self.date = val
77     def __str__(self):
78         if self.name and self.email:
79             return '%s <%s>' % (self.name, self.email)
80         else:
81             raise GitException, 'not enough identity data'
82
83 class Commit:
84     """Handle the commit objects
85     """
86     def __init__(self, id_hash):
87         self.__id_hash = id_hash
88
89         lines = GRun('cat-file', 'commit', id_hash).output_lines()
90         for i in range(len(lines)):
91             line = lines[i]
92             if not line:
93                 break # we've seen all the header fields
94             key, val = line.split(' ', 1)
95             if key == 'tree':
96                 self.__tree = val
97             elif key == 'author':
98                 self.__author = val
99             elif key == 'committer':
100                 self.__committer = val
101             else:
102                 pass # ignore other headers
103         self.__log = '\n'.join(lines[i+1:])
104
105     def get_id_hash(self):
106         return self.__id_hash
107
108     def get_tree(self):
109         return self.__tree
110
111     def get_parent(self):
112         parents = self.get_parents()
113         if parents:
114             return parents[0]
115         else:
116             return None
117
118     def get_parents(self):
119         return GRun('rev-list', '--parents', '--max-count=1', self.__id_hash
120                     ).output_one_line().split()[1:]
121
122     def get_author(self):
123         return self.__author
124
125     def get_committer(self):
126         return self.__committer
127
128     def get_log(self):
129         return self.__log
130
131     def __str__(self):
132         return self.get_id_hash()
133
134 # dictionary of Commit objects, used to avoid multiple calls to git
135 __commits = dict()
136
137 #
138 # Functions
139 #
140
141 def get_commit(id_hash):
142     """Commit objects factory. Save/look-up them in the __commits
143     dictionary
144     """
145     global __commits
146
147     if id_hash in __commits:
148         return __commits[id_hash]
149     else:
150         commit = Commit(id_hash)
151         __commits[id_hash] = commit
152         return commit
153
154 def get_conflicts():
155     """Return the list of file conflicts
156     """
157     conflicts_file = os.path.join(basedir.get(), 'conflicts')
158     if os.path.isfile(conflicts_file):
159         f = file(conflicts_file)
160         names = [line.strip() for line in f.readlines()]
161         f.close()
162         return names
163     else:
164         return None
165
166 def exclude_files():
167     files = [os.path.join(basedir.get(), 'info', 'exclude')]
168     user_exclude = config.get('core.excludesfile')
169     if user_exclude:
170         files.append(user_exclude)
171     return files
172
173 def ls_files(files, tree = 'HEAD', full_name = True):
174     """Return the files known to GIT or raise an error otherwise. It also
175     converts the file to the full path relative the the .git directory.
176     """
177     if not files:
178         return []
179
180     args = []
181     if tree:
182         args.append('--with-tree=%s' % tree)
183     if full_name:
184         args.append('--full-name')
185     args.append('--')
186     args.extend(files)
187     try:
188         return GRun('ls-files', '--error-unmatch', *args).output_lines()
189     except GitRunException:
190         # just hide the details of the 'git ls-files' command we use
191         raise GitException, \
192             'Some of the given paths are either missing or not known to GIT'
193
194 def tree_status(files = None, tree_id = 'HEAD', unknown = False,
195                   noexclude = True, verbose = False, diff_flags = []):
196     """Get the status of all changed files, or of a selected set of
197     files. Returns a list of pairs - (status, filename).
198
199     If 'not files', it will check all files, and optionally all
200     unknown files.  If 'files' is a list, it will only check the files
201     in the list.
202     """
203     assert not files or not unknown
204
205     if verbose:
206         out.start('Checking for changes in the working directory')
207
208     refresh_index()
209
210     if files is None:
211         files = []
212     cache_files = []
213
214     # unknown files
215     if unknown:
216         cmd = ['ls-files', '-z', '--others', '--directory',
217                '--no-empty-directory']
218         if not noexclude:
219             cmd += ['--exclude=%s' % s for s in
220                     ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
221             cmd += ['--exclude-per-directory=.gitignore']
222             cmd += ['--exclude-from=%s' % fn
223                     for fn in exclude_files()
224                     if os.path.exists(fn)]
225
226         lines = GRun(*cmd).raw_output().split('\0')
227         cache_files += [('?', line) for line in lines if line]
228
229     # conflicted files
230     conflicts = get_conflicts()
231     if not conflicts:
232         conflicts = []
233     cache_files += [('C', filename) for filename in conflicts
234                     if not files or filename in files]
235     reported_files = set(conflicts)
236     files_left = [f for f in files if f not in reported_files]
237
238     # files in the index. Only execute this code if no files were
239     # specified when calling the function (i.e. report all files) or
240     # files were specified but already found in the previous step
241     if not files or files_left:
242         args = diff_flags + [tree_id]
243         if files_left:
244             args += ['--'] + files_left
245         t = None
246         for line in GRun('diff-index', '-z', *args).raw_output().split('\0'):
247             if not line:
248                 # There's a zero byte at the end of the output, which
249                 # gives us an empty string as the last "line".
250                 continue
251             if t == None:
252                 mode_a, mode_b, sha1_a, sha1_b, t = line.split(' ')
253             else:
254                 # the condition is needed in case files is emtpy and
255                 # diff-index lists those already reported
256                 if not line in reported_files:
257                     cache_files.append((t, line))
258                     reported_files.add(line)
259                 t = None
260         files_left = [f for f in files if f not in reported_files]
261
262     # files in the index but changed on (or removed from) disk. Only
263     # execute this code if no files were specified when calling the
264     # function (i.e. report all files) or files were specified but
265     # already found in the previous step
266     if not files or files_left:
267         args = list(diff_flags)
268         if files_left:
269             args += ['--'] + files_left
270         for line in GRun('diff-files', *args).output_lines():
271             fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
272             # the condition is needed in case files is empty and
273             # diff-files lists those already reported
274             if fs[1] not in reported_files:
275                 cache_files.append(fs)
276                 reported_files.add(fs[1])
277
278     if verbose:
279         out.done()
280
281     return cache_files
282
283 def local_changes(verbose = True):
284     """Return true if there are local changes in the tree
285     """
286     return len(tree_status(verbose = verbose)) != 0
287
288 def get_heads():
289     heads = []
290     hr = re.compile(r'^[0-9a-f]{40} refs/heads/(.+)$')
291     for line in GRun('show-ref', '--heads').output_lines():
292         m = hr.match(line)
293         heads.append(m.group(1))
294     return heads
295
296 # HEAD value cached
297 __head = None
298
299 def get_head():
300     """Verifies the HEAD and returns the SHA1 id that represents it
301     """
302     global __head
303
304     if not __head:
305         __head = rev_parse('HEAD')
306     return __head
307
308 class DetachedHeadException(GitException):
309     def __init__(self):
310         GitException.__init__(self, 'Not on any branch')
311
312 def get_head_file():
313     """Return the name of the file pointed to by the HEAD symref.
314     Throw an exception if HEAD is detached."""
315     try:
316         return strip_prefix(
317             'refs/heads/', GRun('symbolic-ref', '-q', 'HEAD'
318                                 ).output_one_line())
319     except GitRunException:
320         raise DetachedHeadException()
321
322 def set_head_file(ref):
323     """Resets HEAD to point to a new ref
324     """
325     # head cache flushing is needed since we might have a different value
326     # in the new head
327     __clear_head_cache()
328     try:
329         GRun('symbolic-ref', 'HEAD', 'refs/heads/%s' % ref).run()
330     except GitRunException:
331         raise GitException, 'Could not set head to "%s"' % ref
332
333 def set_ref(ref, val):
334     """Point ref at a new commit object."""
335     try:
336         GRun('update-ref', ref, val).run()
337     except GitRunException:
338         raise GitException, 'Could not update %s to "%s".' % (ref, val)
339
340 def set_branch(branch, val):
341     set_ref('refs/heads/%s' % branch, val)
342
343 def __set_head(val):
344     """Sets the HEAD value
345     """
346     global __head
347
348     if not __head or __head != val:
349         set_ref('HEAD', val)
350         __head = val
351
352     # only allow SHA1 hashes
353     assert(len(__head) == 40)
354
355 def __clear_head_cache():
356     """Sets the __head to None so that a re-read is forced
357     """
358     global __head
359
360     __head = None
361
362 def refresh_index():
363     """Refresh index with stat() information from the working directory.
364     """
365     GRun('update-index', '-q', '--unmerged', '--refresh').run()
366
367 def rev_parse(git_id):
368     """Parse the string and return a verified SHA1 id
369     """
370     try:
371         return GRun('rev-parse', '--verify', git_id
372                     ).discard_stderr().output_one_line()
373     except GitRunException:
374         raise GitException, 'Unknown revision: %s' % git_id
375
376 def ref_exists(ref):
377     try:
378         rev_parse(ref)
379         return True
380     except GitException:
381         return False
382
383 def branch_exists(branch):
384     return ref_exists('refs/heads/%s' % branch)
385
386 def create_branch(new_branch, tree_id = None):
387     """Create a new branch in the git repository
388     """
389     if branch_exists(new_branch):
390         raise GitException, 'Branch "%s" already exists' % new_branch
391
392     current_head_file = get_head_file()
393     current_head = get_head()
394     set_head_file(new_branch)
395     __set_head(current_head)
396
397     # a checkout isn't needed if new branch points to the current head
398     if tree_id:
399         try:
400             switch(tree_id)
401         except GitException:
402             # Tree switching failed. Revert the head file
403             set_head_file(current_head_file)
404             delete_branch(new_branch)
405             raise
406
407     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
408         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
409
410 def switch_branch(new_branch):
411     """Switch to a git branch
412     """
413     global __head
414
415     if not branch_exists(new_branch):
416         raise GitException, 'Branch "%s" does not exist' % new_branch
417
418     tree_id = rev_parse('refs/heads/%s^{commit}' % new_branch)
419     if tree_id != get_head():
420         refresh_index()
421         try:
422             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
423         except GitRunException:
424             raise GitException, 'read-tree failed (local changes maybe?)'
425         __head = tree_id
426     set_head_file(new_branch)
427
428     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
429         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
430
431 def delete_ref(ref):
432     if not ref_exists(ref):
433         raise GitException, '%s does not exist' % ref
434     sha1 = GRun('show-ref', '-s', ref).output_one_line()
435     try:
436         GRun('update-ref', '-d', ref, sha1).run()
437     except GitRunException:
438         raise GitException, 'Failed to delete ref %s' % ref
439
440 def delete_branch(name):
441     delete_ref('refs/heads/%s' % name)
442
443 def rename_ref(from_ref, to_ref):
444     if not ref_exists(from_ref):
445         raise GitException, '"%s" does not exist' % from_ref
446     if ref_exists(to_ref):
447         raise GitException, '"%s" already exists' % to_ref
448
449     sha1 = GRun('show-ref', '-s', from_ref).output_one_line()
450     try:
451         GRun('update-ref', to_ref, sha1, '0'*40).run()
452     except GitRunException:
453         raise GitException, 'Failed to create new ref %s' % to_ref
454     try:
455         GRun('update-ref', '-d', from_ref, sha1).run()
456     except GitRunException:
457         raise GitException, 'Failed to delete ref %s' % from_ref
458
459 def rename_branch(from_name, to_name):
460     """Rename a git branch."""
461     rename_ref('refs/heads/%s' % from_name, 'refs/heads/%s' % to_name)
462     try:
463         if get_head_file() == from_name:
464             set_head_file(to_name)
465     except DetachedHeadException:
466         pass # detached HEAD, so the renamee can't be the current branch
467     reflog_dir = os.path.join(basedir.get(), 'logs', 'refs', 'heads')
468     if os.path.exists(reflog_dir) \
469            and os.path.exists(os.path.join(reflog_dir, from_name)):
470         rename(reflog_dir, from_name, to_name)
471
472 def add(names):
473     """Add the files or recursively add the directory contents
474     """
475     # generate the file list
476     files = []
477     for i in names:
478         if not os.path.exists(i):
479             raise GitException, 'Unknown file or directory: %s' % i
480
481         if os.path.isdir(i):
482             # recursive search. We only add files
483             for root, dirs, local_files in os.walk(i):
484                 for name in [os.path.join(root, f) for f in local_files]:
485                     if os.path.isfile(name):
486                         files.append(os.path.normpath(name))
487         elif os.path.isfile(i):
488             files.append(os.path.normpath(i))
489         else:
490             raise GitException, '%s is not a file or directory' % i
491
492     if files:
493         try:
494             GRun('update-index', '--add', '--').xargs(files)
495         except GitRunException:
496             raise GitException, 'Unable to add file'
497
498 def __copy_single(source, target, target2=''):
499     """Copy file or dir named 'source' to name target+target2"""
500
501     # "source" (file or dir) must match one or more git-controlled file
502     realfiles = GRun('ls-files', source).output_lines()
503     if len(realfiles) == 0:
504         raise GitException, '"%s" matches no git-controled files' % source
505
506     if os.path.isdir(source):
507         # physically copy the files, and record them to add them in one run
508         newfiles = []
509         re_string='^'+source+'/(.*)$'
510         prefix_regexp = re.compile(re_string)
511         for f in [f.strip() for f in realfiles]:
512             m = prefix_regexp.match(f)
513             if not m:
514                 raise Exception, '"%s" does not match "%s"' % (f, re_string)
515             newname = target+target2+'/'+m.group(1)
516             if not os.path.exists(os.path.dirname(newname)):
517                 os.makedirs(os.path.dirname(newname))
518             copyfile(f, newname)
519             newfiles.append(newname)
520
521         add(newfiles)
522     else: # files, symlinks, ...
523         newname = target+target2
524         copyfile(source, newname)
525         add([newname])
526
527
528 def copy(filespecs, target):
529     if os.path.isdir(target):
530         # target is a directory: copy each entry on the command line,
531         # with the same name, into the target
532         target = target.rstrip('/')
533         
534         # first, check that none of the children of the target
535         # matching the command line aleady exist
536         for filespec in filespecs:
537             entry = target+ '/' + os.path.basename(filespec.rstrip('/'))
538             if os.path.exists(entry):
539                 raise GitException, 'Target "%s" already exists' % entry
540         
541         for filespec in filespecs:
542             filespec = filespec.rstrip('/')
543             basename = '/' + os.path.basename(filespec)
544             __copy_single(filespec, target, basename)
545
546     elif os.path.exists(target):
547         raise GitException, 'Target "%s" exists but is not a directory' % target
548     elif len(filespecs) != 1:
549         raise GitException, 'Cannot copy more than one file to non-directory'
550
551     else:
552         # at this point: len(filespecs)==1 and target does not exist
553
554         # check target directory
555         targetdir = os.path.dirname(target)
556         if targetdir != '' and not os.path.isdir(targetdir):
557             raise GitException, 'Target directory "%s" does not exist' % targetdir
558
559         __copy_single(filespecs[0].rstrip('/'), target)
560         
561
562 def rm(files, force = False):
563     """Remove a file from the repository
564     """
565     if not force:
566         for f in files:
567             if os.path.exists(f):
568                 raise GitException, '%s exists. Remove it first' %f
569         if files:
570             GRun('update-index', '--remove', '--').xargs(files)
571     else:
572         if files:
573             GRun('update-index', '--force-remove', '--').xargs(files)
574
575 # Persons caching
576 __user = None
577 __author = None
578 __committer = None
579
580 def user():
581     """Return the user information.
582     """
583     global __user
584     if not __user:
585         name=config.get('user.name')
586         email=config.get('user.email')
587         __user = Person(name, email)
588     return __user;
589
590 def author():
591     """Return the author information.
592     """
593     global __author
594     if not __author:
595         try:
596             # the environment variables take priority over config
597             try:
598                 date = os.environ['GIT_AUTHOR_DATE']
599             except KeyError:
600                 date = ''
601             __author = Person(os.environ['GIT_AUTHOR_NAME'],
602                               os.environ['GIT_AUTHOR_EMAIL'],
603                               date)
604         except KeyError:
605             __author = user()
606     return __author
607
608 def committer():
609     """Return the author information.
610     """
611     global __committer
612     if not __committer:
613         try:
614             # the environment variables take priority over config
615             try:
616                 date = os.environ['GIT_COMMITTER_DATE']
617             except KeyError:
618                 date = ''
619             __committer = Person(os.environ['GIT_COMMITTER_NAME'],
620                                  os.environ['GIT_COMMITTER_EMAIL'],
621                                  date)
622         except KeyError:
623             __committer = user()
624     return __committer
625
626 def update_cache(files = None, force = False):
627     """Update the cache information for the given files
628     """
629     cache_files = tree_status(files, verbose = False)
630
631     # everything is up-to-date
632     if len(cache_files) == 0:
633         return False
634
635     # check for unresolved conflicts
636     if not force and [x for x in cache_files
637                       if x[0] not in ['M', 'N', 'A', 'D']]:
638         raise GitException, 'Updating cache failed: unresolved conflicts'
639
640     # update the cache
641     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
642     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
643     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
644
645     GRun('update-index', '--add', '--').xargs(add_files)
646     GRun('update-index', '--force-remove', '--').xargs(rm_files)
647     GRun('update-index', '--').xargs(m_files)
648
649     return True
650
651 def commit(message, files = None, parents = None, allowempty = False,
652            cache_update = True, tree_id = None, set_head = False,
653            author_name = None, author_email = None, author_date = None,
654            committer_name = None, committer_email = None):
655     """Commit the current tree to repository
656     """
657     if not parents:
658         parents = []
659
660     # Get the tree status
661     if cache_update and parents != []:
662         changes = update_cache(files)
663         if not changes and not allowempty:
664             raise GitException, 'No changes to commit'
665
666     # get the commit message
667     if not message:
668         message = '\n'
669     elif message[-1:] != '\n':
670         message += '\n'
671
672     # write the index to repository
673     if tree_id == None:
674         tree_id = GRun('write-tree').output_one_line()
675         set_head = True
676
677     # the commit
678     env = {}
679     if author_name:
680         env['GIT_AUTHOR_NAME'] = author_name
681     if author_email:
682         env['GIT_AUTHOR_EMAIL'] = author_email
683     if author_date:
684         env['GIT_AUTHOR_DATE'] = author_date
685     if committer_name:
686         env['GIT_COMMITTER_NAME'] = committer_name
687     if committer_email:
688         env['GIT_COMMITTER_EMAIL'] = committer_email
689     commit_id = GRun('commit-tree', tree_id,
690                      *sum([['-p', p] for p in parents], [])
691                      ).env(env).raw_input(message).output_one_line()
692     if set_head:
693         __set_head(commit_id)
694
695     return commit_id
696
697 def apply_diff(rev1, rev2, check_index = True, files = None):
698     """Apply the diff between rev1 and rev2 onto the current
699     index. This function doesn't need to raise an exception since it
700     is only used for fast-pushing a patch. If this operation fails,
701     the pushing would fall back to the three-way merge.
702     """
703     if check_index:
704         index_opt = ['--index']
705     else:
706         index_opt = []
707
708     if not files:
709         files = []
710
711     diff_str = diff(files, rev1, rev2)
712     if diff_str:
713         try:
714             GRun('apply', *index_opt).raw_input(
715                 diff_str).discard_stderr().no_output()
716         except GitRunException:
717             return False
718
719     return True
720
721 def merge(base, head1, head2, recursive = False):
722     """Perform a 3-way merge between base, head1 and head2 into the
723     local tree
724     """
725     refresh_index()
726
727     err_output = None
728     if recursive:
729         # this operation tracks renames but it is slower (used in
730         # general when pushing or picking patches)
731         try:
732             # discard output to mask the verbose prints of the tool
733             GRun('merge-recursive', base, '--', head1, head2
734                  ).discard_output()
735         except GitRunException, ex:
736             err_output = str(ex)
737             pass
738     else:
739         # the fast case where we don't track renames (used when the
740         # distance between base and heads is small, i.e. folding or
741         # synchronising patches)
742         try:
743             GRun('read-tree', '-u', '-m', '--aggressive',
744                  base, head1, head2).run()
745         except GitRunException:
746             raise GitException, 'read-tree failed (local changes maybe?)'
747
748     # check the index for unmerged entries
749     files = {}
750     stages_re = re.compile('^([0-7]+) ([0-9a-f]{40}) ([1-3])\t(.*)$', re.S)
751
752     for line in GRun('ls-files', '--unmerged', '--stage', '-z'
753                      ).raw_output().split('\0'):
754         if not line:
755             continue
756
757         mode, hash, stage, path = stages_re.findall(line)[0]
758
759         if not path in files:
760             files[path] = {}
761             files[path]['1'] = ('', '')
762             files[path]['2'] = ('', '')
763             files[path]['3'] = ('', '')
764
765         files[path][stage] = (mode, hash)
766
767     if err_output and not files:
768         # if no unmerged files, there was probably a different type of
769         # error and we have to abort the merge
770         raise GitException, err_output
771
772     # merge the unmerged files
773     errors = False
774     for path in files:
775         # remove additional files that might be generated for some
776         # newer versions of GIT
777         for suffix in [base, head1, head2]:
778             if not suffix:
779                 continue
780             fname = path + '~' + suffix
781             if os.path.exists(fname):
782                 os.remove(fname)
783
784         stages = files[path]
785         if gitmergeonefile.merge(stages['1'][1], stages['2'][1],
786                                  stages['3'][1], path, stages['1'][0],
787                                  stages['2'][0], stages['3'][0]) != 0:
788             errors = True
789
790     if errors:
791         raise GitException, 'GIT index merging failed (possible conflicts)'
792
793 def diff(files = None, rev1 = 'HEAD', rev2 = None, diff_flags = [],
794          binary = True):
795     """Show the diff between rev1 and rev2
796     """
797     if not files:
798         files = []
799     if binary and '--binary' not in diff_flags:
800         diff_flags = diff_flags + ['--binary']
801
802     if rev1 and rev2:
803         return GRun('diff-tree', '-p',
804                     *(diff_flags + [rev1, rev2, '--'] + files)).raw_output()
805     elif rev1 or rev2:
806         refresh_index()
807         if rev2:
808             return GRun('diff-index', '-p', '-R',
809                         *(diff_flags + [rev2, '--'] + files)).raw_output()
810         else:
811             return GRun('diff-index', '-p',
812                         *(diff_flags + [rev1, '--'] + files)).raw_output()
813     else:
814         return ''
815
816 # TODO: take another parameter representing a diff string as we
817 # usually invoke git.diff() form the calling functions
818 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
819     """Return the diffstat between rev1 and rev2."""
820     return GRun('apply', '--stat', '--summary'
821                 ).raw_input(diff(files, rev1, rev2)).raw_output()
822
823 def files(rev1, rev2, diff_flags = []):
824     """Return the files modified between rev1 and rev2
825     """
826
827     result = []
828     for line in GRun('diff-tree', *(diff_flags + ['-r', rev1, rev2])
829                      ).output_lines():
830         result.append('%s %s' % tuple(line.split(' ', 4)[-1].split('\t', 1)))
831
832     return '\n'.join(result)
833
834 def barefiles(rev1, rev2):
835     """Return the files modified between rev1 and rev2, without status info
836     """
837
838     result = []
839     for line in GRun('diff-tree', '-r', rev1, rev2).output_lines():
840         result.append(line.split(' ', 4)[-1].split('\t', 1)[-1])
841
842     return '\n'.join(result)
843
844 def pretty_commit(commit_id = 'HEAD', flags = []):
845     """Return a given commit (log + diff)
846     """
847     return GRun('show', *(flags + [commit_id])).raw_output()
848
849 def checkout(files = None, tree_id = None, force = False):
850     """Check out the given or all files
851     """
852     if tree_id:
853         try:
854             GRun('read-tree', '--reset', tree_id).run()
855         except GitRunException:
856             raise GitException, 'Failed "git read-tree" --reset %s' % tree_id
857
858     cmd = ['checkout-index', '-q', '-u']
859     if force:
860         cmd.append('-f')
861     if files:
862         GRun(*(cmd + ['--'])).xargs(files)
863     else:
864         GRun(*(cmd + ['-a'])).run()
865
866 def switch(tree_id, keep = False):
867     """Switch the tree to the given id
868     """
869     if keep:
870         # only update the index while keeping the local changes
871         GRun('read-tree', tree_id).run()
872     else:
873         refresh_index()
874         try:
875             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
876         except GitRunException:
877             raise GitException, 'read-tree failed (local changes maybe?)'
878
879     __set_head(tree_id)
880
881 def reset(files = None, tree_id = None, check_out = True):
882     """Revert the tree changes relative to the given tree_id. It removes
883     any local changes
884     """
885     if not tree_id:
886         tree_id = get_head()
887
888     if check_out:
889         cache_files = tree_status(files, tree_id)
890         # files which were added but need to be removed
891         rm_files =  [x[1] for x in cache_files if x[0] in ['A']]
892
893         checkout(files, tree_id, True)
894         # checkout doesn't remove files
895         map(os.remove, rm_files)
896
897     # if the reset refers to the whole tree, switch the HEAD as well
898     if not files:
899         __set_head(tree_id)
900
901 def fetch(repository = 'origin', refspec = None):
902     """Fetches changes from the remote repository, using 'git fetch'
903     by default.
904     """
905     # we update the HEAD
906     __clear_head_cache()
907
908     args = [repository]
909     if refspec:
910         args.append(refspec)
911
912     command = config.get('branch.%s.stgit.fetchcmd' % get_head_file()) or \
913               config.get('stgit.fetchcmd')
914     Run(*(command.split() + args)).run()
915
916 def pull(repository = 'origin', refspec = None):
917     """Fetches changes from the remote repository, using 'git pull'
918     by default.
919     """
920     # we update the HEAD
921     __clear_head_cache()
922
923     args = [repository]
924     if refspec:
925         args.append(refspec)
926
927     command = config.get('branch.%s.stgit.pullcmd' % get_head_file()) or \
928               config.get('stgit.pullcmd')
929     Run(*(command.split() + args)).run()
930
931 def rebase(tree_id = None):
932     """Rebase the current tree to the give tree_id. The tree_id
933     argument may be something other than a GIT id if an external
934     command is invoked.
935     """
936     command = config.get('branch.%s.stgit.rebasecmd' % get_head_file()) \
937                 or config.get('stgit.rebasecmd')
938     if tree_id:
939         args = [tree_id]
940     elif command:
941         args = []
942     else:
943         raise GitException, 'Default rebasing requires a commit id'
944     if command:
945         # clear the HEAD cache as the custom rebase command will update it
946         __clear_head_cache()
947         Run(*(command.split() + args)).run()
948     else:
949         # default rebasing
950         reset(tree_id = tree_id)
951
952 def repack():
953     """Repack all objects into a single pack
954     """
955     GRun('repack', '-a', '-d', '-f').run()
956
957 def apply_patch(filename = None, diff = None, base = None,
958                 fail_dump = True):
959     """Apply a patch onto the current or given index. There must not
960     be any local changes in the tree, otherwise the command fails
961     """
962     if diff is None:
963         if filename:
964             f = file(filename)
965         else:
966             f = sys.stdin
967         diff = f.read()
968         if filename:
969             f.close()
970
971     if base:
972         orig_head = get_head()
973         switch(base)
974     else:
975         refresh_index()
976
977     try:
978         GRun('apply', '--index').raw_input(diff).no_output()
979     except GitRunException:
980         if base:
981             switch(orig_head)
982         if fail_dump:
983             # write the failed diff to a file
984             f = file('.stgit-failed.patch', 'w+')
985             f.write(diff)
986             f.close()
987             out.warn('Diff written to the .stgit-failed.patch file')
988
989         raise
990
991     if base:
992         top = commit(message = 'temporary commit used for applying a patch',
993                      parents = [base])
994         switch(orig_head)
995         merge(base, orig_head, top)
996
997 def clone(repository, local_dir):
998     """Clone a remote repository. At the moment, just use the
999     'git clone' script
1000     """
1001     GRun('clone', repository, local_dir).run()
1002
1003 def modifying_revs(files, base_rev, head_rev):
1004     """Return the revisions from the list modifying the given files."""
1005     return GRun('rev-list', '%s..%s' % (base_rev, head_rev), '--', *files
1006                 ).output_lines()
1007
1008 def refspec_localpart(refspec):
1009     m = re.match('^[^:]*:([^:]*)$', refspec)
1010     if m:
1011         return m.group(1)
1012     else:
1013         raise GitException, 'Cannot parse refspec "%s"' % line
1014
1015 def refspec_remotepart(refspec):
1016     m = re.match('^([^:]*):[^:]*$', refspec)
1017     if m:
1018         return m.group(1)
1019     else:
1020         raise GitException, 'Cannot parse refspec "%s"' % line
1021
1022 def __remotes_from_config():
1023     return config.sections_matching(r'remote\.(.*)\.url')
1024
1025 def __remotes_from_dir(dir):
1026     d = os.path.join(basedir.get(), dir)
1027     if os.path.exists(d):
1028         return os.listdir(d)
1029     else:
1030         return []
1031
1032 def remotes_list():
1033     """Return the list of remotes in the repository
1034     """
1035     return (set(__remotes_from_config())
1036             | set(__remotes_from_dir('remotes'))
1037             | set(__remotes_from_dir('branches')))
1038
1039 def remotes_local_branches(remote):
1040     """Returns the list of local branches fetched from given remote
1041     """
1042
1043     branches = []
1044     if remote in __remotes_from_config():
1045         for line in config.getall('remote.%s.fetch' % remote):
1046             branches.append(refspec_localpart(line))
1047     elif remote in __remotes_from_dir('remotes'):
1048         stream = open(os.path.join(basedir.get(), 'remotes', remote), 'r')
1049         for line in stream:
1050             # Only consider Pull lines
1051             m = re.match('^Pull: (.*)\n$', line)
1052             if m:
1053                 branches.append(refspec_localpart(m.group(1)))
1054         stream.close()
1055     elif remote in __remotes_from_dir('branches'):
1056         # old-style branches only declare one branch
1057         branches.append('refs/heads/'+remote);
1058     else:
1059         raise GitException, 'Unknown remote "%s"' % remote
1060
1061     return branches
1062
1063 def identify_remote(branchname):
1064     """Return the name for the remote to pull the given branchname
1065     from, or None if we believe it is a local branch.
1066     """
1067
1068     for remote in remotes_list():
1069         if branchname in remotes_local_branches(remote):
1070             return remote
1071
1072     # if we get here we've found nothing, the branch is a local one
1073     return None
1074
1075 def fetch_head():
1076     """Return the git id for the tip of the parent branch as left by
1077     'git fetch'.
1078     """
1079
1080     fetch_head=None
1081     stream = open(os.path.join(basedir.get(), 'FETCH_HEAD'), "r")
1082     for line in stream:
1083         # Only consider lines not tagged not-for-merge
1084         m = re.match('^([^\t]*)\t\t', line)
1085         if m:
1086             if fetch_head:
1087                 raise GitException, 'StGit does not support multiple FETCH_HEAD'
1088             else:
1089                 fetch_head=m.group(1)
1090     stream.close()
1091
1092     if not fetch_head:
1093         out.warn('No for-merge remote head found in FETCH_HEAD')
1094
1095     # here we are sure to have a single fetch_head
1096     return fetch_head
1097
1098 def all_refs():
1099     """Return a list of all refs in the current repository.
1100     """
1101
1102     return [line.split()[1] for line in GRun('show-ref').output_lines()]