chiark / gitweb /
6140fd9136c1388d196417f9039cf8b71209c688
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re, gitmergeonefile
22 from shutil import copyfile
23
24 from stgit.exception import *
25 from stgit import basedir
26 from stgit.utils import *
27 from stgit.out import *
28 from stgit.run import *
29 from stgit.config import config
30
31 # git exception class
32 class GitException(StgException):
33     pass
34
35 # When a subprocess has a problem, we want the exception to be a
36 # subclass of GitException.
37 class GitRunException(GitException):
38     pass
39 class GRun(Run):
40     exc = GitRunException
41     def __init__(self, *cmd):
42         """Initialise the Run object and insert the 'git' command name.
43         """
44         Run.__init__(self, 'git', *cmd)
45
46
47 #
48 # Classes
49 #
50
51 class Person:
52     """An author, committer, etc."""
53     def __init__(self, name = None, email = None, date = '',
54                  desc = None):
55         self.name = self.email = self.date = None
56         if name or email or date:
57             assert not desc
58             self.name = name
59             self.email = email
60             self.date = date
61         elif desc:
62             assert not (name or email or date)
63             def parse_desc(s):
64                 m = re.match(r'^(.+)<(.+)>(.*)$', s)
65                 assert m
66                 return [x.strip() or None for x in m.groups()]
67             self.name, self.email, self.date = parse_desc(desc)
68     def set_name(self, val):
69         if val:
70             self.name = val
71     def set_email(self, val):
72         if val:
73             self.email = val
74     def set_date(self, val):
75         if val:
76             self.date = val
77     def __str__(self):
78         if self.name and self.email:
79             return '%s <%s>' % (self.name, self.email)
80         else:
81             raise GitException, 'not enough identity data'
82
83 class Commit:
84     """Handle the commit objects
85     """
86     def __init__(self, id_hash):
87         self.__id_hash = id_hash
88
89         lines = GRun('cat-file', 'commit', id_hash).output_lines()
90         for i in range(len(lines)):
91             line = lines[i]
92             if not line:
93                 break # we've seen all the header fields
94             key, val = line.split(' ', 1)
95             if key == 'tree':
96                 self.__tree = val
97             elif key == 'author':
98                 self.__author = val
99             elif key == 'committer':
100                 self.__committer = val
101             else:
102                 pass # ignore other headers
103         self.__log = '\n'.join(lines[i+1:])
104
105     def get_id_hash(self):
106         return self.__id_hash
107
108     def get_tree(self):
109         return self.__tree
110
111     def get_parent(self):
112         parents = self.get_parents()
113         if parents:
114             return parents[0]
115         else:
116             return None
117
118     def get_parents(self):
119         return GRun('rev-list', '--parents', '--max-count=1', self.__id_hash
120                     ).output_one_line().split()[1:]
121
122     def get_author(self):
123         return self.__author
124
125     def get_committer(self):
126         return self.__committer
127
128     def get_log(self):
129         return self.__log
130
131     def __str__(self):
132         return self.get_id_hash()
133
134 # dictionary of Commit objects, used to avoid multiple calls to git
135 __commits = dict()
136
137 #
138 # Functions
139 #
140
141 def get_commit(id_hash):
142     """Commit objects factory. Save/look-up them in the __commits
143     dictionary
144     """
145     global __commits
146
147     if id_hash in __commits:
148         return __commits[id_hash]
149     else:
150         commit = Commit(id_hash)
151         __commits[id_hash] = commit
152         return commit
153
154 def get_conflicts():
155     """Return the list of file conflicts
156     """
157     conflicts_file = os.path.join(basedir.get(), 'conflicts')
158     if os.path.isfile(conflicts_file):
159         f = file(conflicts_file)
160         names = [line.strip() for line in f.readlines()]
161         f.close()
162         return names
163     else:
164         return None
165
166 def exclude_files():
167     files = [os.path.join(basedir.get(), 'info', 'exclude')]
168     user_exclude = config.get('core.excludesfile')
169     if user_exclude:
170         files.append(user_exclude)
171     return files
172
173 def ls_files(files, tree = 'HEAD', full_name = True):
174     """Return the files known to GIT or raise an error otherwise. It also
175     converts the file to the full path relative the the .git directory.
176     """
177     if not files:
178         return []
179
180     args = []
181     if tree:
182         args.append('--with-tree=%s' % tree)
183     if full_name:
184         args.append('--full-name')
185     args.append('--')
186     args.extend(files)
187     try:
188         return GRun('ls-files', '--error-unmatch', *args).output_lines()
189     except GitRunException:
190         # just hide the details of the 'git ls-files' command we use
191         raise GitException, \
192             'Some of the given paths are either missing or not known to GIT'
193
194 def tree_status(files = None, tree_id = 'HEAD', unknown = False,
195                   noexclude = True, verbose = False, diff_flags = []):
196     """Get the status of all changed files, or of a selected set of
197     files. Returns a list of pairs - (status, filename).
198
199     If 'not files', it will check all files, and optionally all
200     unknown files.  If 'files' is a list, it will only check the files
201     in the list.
202     """
203     assert not files or not unknown
204
205     if verbose:
206         out.start('Checking for changes in the working directory')
207
208     refresh_index()
209
210     if files is None:
211         files = []
212     cache_files = []
213
214     # unknown files
215     if unknown:
216         cmd = ['ls-files', '-z', '--others', '--directory',
217                '--no-empty-directory']
218         if not noexclude:
219             cmd += ['--exclude=%s' % s for s in
220                     ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
221             cmd += ['--exclude-per-directory=.gitignore']
222             cmd += ['--exclude-from=%s' % fn
223                     for fn in exclude_files()
224                     if os.path.exists(fn)]
225
226         lines = GRun(*cmd).raw_output().split('\0')
227         cache_files += [('?', line) for line in lines if line]
228
229     # conflicted files
230     conflicts = get_conflicts()
231     if not conflicts:
232         conflicts = []
233     cache_files += [('C', filename) for filename in conflicts
234                     if not files or filename in files]
235     reported_files = set(conflicts)
236     files_left = [f for f in files if f not in reported_files]
237
238     # files in the index. Only execute this code if no files were
239     # specified when calling the function (i.e. report all files) or
240     # files were specified but already found in the previous step
241     if not files or files_left:
242         args = diff_flags + [tree_id]
243         if files_left:
244             args += ['--'] + files_left
245         for line in GRun('diff-index', *args).output_lines():
246             fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
247             # the condition is needed in case files is emtpy and
248             # diff-index lists those already reported
249             if fs[1] not in reported_files:
250                 cache_files.append(fs)
251                 reported_files.add(fs[1])
252         files_left = [f for f in files if f not in reported_files]
253
254     # files in the index but changed on (or removed from) disk. Only
255     # execute this code if no files were specified when calling the
256     # function (i.e. report all files) or files were specified but
257     # already found in the previous step
258     if not files or files_left:
259         args = list(diff_flags)
260         if files_left:
261             args += ['--'] + files_left
262         for line in GRun('diff-files', *args).output_lines():
263             fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
264             # the condition is needed in case files is empty and
265             # diff-files lists those already reported
266             if fs[1] not in reported_files:
267                 cache_files.append(fs)
268                 reported_files.add(fs[1])
269
270     if verbose:
271         out.done()
272
273     return cache_files
274
275 def local_changes(verbose = True):
276     """Return true if there are local changes in the tree
277     """
278     return len(tree_status(verbose = verbose)) != 0
279
280 def get_heads():
281     heads = []
282     hr = re.compile(r'^[0-9a-f]{40} refs/heads/(.+)$')
283     for line in GRun('show-ref', '--heads').output_lines():
284         m = hr.match(line)
285         heads.append(m.group(1))
286     return heads
287
288 # HEAD value cached
289 __head = None
290
291 def get_head():
292     """Verifies the HEAD and returns the SHA1 id that represents it
293     """
294     global __head
295
296     if not __head:
297         __head = rev_parse('HEAD')
298     return __head
299
300 class DetachedHeadException(GitException):
301     def __init__(self):
302         GitException.__init__(self, 'Not on any branch')
303
304 def get_head_file():
305     """Return the name of the file pointed to by the HEAD symref.
306     Throw an exception if HEAD is detached."""
307     try:
308         return strip_prefix(
309             'refs/heads/', GRun('symbolic-ref', '-q', 'HEAD'
310                                 ).output_one_line())
311     except GitRunException:
312         raise DetachedHeadException()
313
314 def set_head_file(ref):
315     """Resets HEAD to point to a new ref
316     """
317     # head cache flushing is needed since we might have a different value
318     # in the new head
319     __clear_head_cache()
320     try:
321         GRun('symbolic-ref', 'HEAD', 'refs/heads/%s' % ref).run()
322     except GitRunException:
323         raise GitException, 'Could not set head to "%s"' % ref
324
325 def set_ref(ref, val):
326     """Point ref at a new commit object."""
327     try:
328         GRun('update-ref', ref, val).run()
329     except GitRunException:
330         raise GitException, 'Could not update %s to "%s".' % (ref, val)
331
332 def set_branch(branch, val):
333     set_ref('refs/heads/%s' % branch, val)
334
335 def __set_head(val):
336     """Sets the HEAD value
337     """
338     global __head
339
340     if not __head or __head != val:
341         set_ref('HEAD', val)
342         __head = val
343
344     # only allow SHA1 hashes
345     assert(len(__head) == 40)
346
347 def __clear_head_cache():
348     """Sets the __head to None so that a re-read is forced
349     """
350     global __head
351
352     __head = None
353
354 def refresh_index():
355     """Refresh index with stat() information from the working directory.
356     """
357     GRun('update-index', '-q', '--unmerged', '--refresh').run()
358
359 def rev_parse(git_id):
360     """Parse the string and return a verified SHA1 id
361     """
362     try:
363         return GRun('rev-parse', '--verify', git_id
364                     ).discard_stderr().output_one_line()
365     except GitRunException:
366         raise GitException, 'Unknown revision: %s' % git_id
367
368 def ref_exists(ref):
369     try:
370         rev_parse(ref)
371         return True
372     except GitException:
373         return False
374
375 def branch_exists(branch):
376     return ref_exists('refs/heads/%s' % branch)
377
378 def create_branch(new_branch, tree_id = None):
379     """Create a new branch in the git repository
380     """
381     if branch_exists(new_branch):
382         raise GitException, 'Branch "%s" already exists' % new_branch
383
384     current_head_file = get_head_file()
385     current_head = get_head()
386     set_head_file(new_branch)
387     __set_head(current_head)
388
389     # a checkout isn't needed if new branch points to the current head
390     if tree_id:
391         try:
392             switch(tree_id)
393         except GitException:
394             # Tree switching failed. Revert the head file
395             set_head_file(current_head_file)
396             delete_branch(new_branch)
397             raise
398
399     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
400         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
401
402 def switch_branch(new_branch):
403     """Switch to a git branch
404     """
405     global __head
406
407     if not branch_exists(new_branch):
408         raise GitException, 'Branch "%s" does not exist' % new_branch
409
410     tree_id = rev_parse('refs/heads/%s^{commit}' % new_branch)
411     if tree_id != get_head():
412         refresh_index()
413         try:
414             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
415         except GitRunException:
416             raise GitException, 'read-tree failed (local changes maybe?)'
417         __head = tree_id
418     set_head_file(new_branch)
419
420     if os.path.isfile(os.path.join(basedir.get(), 'MERGE_HEAD')):
421         os.remove(os.path.join(basedir.get(), 'MERGE_HEAD'))
422
423 def delete_ref(ref):
424     if not ref_exists(ref):
425         raise GitException, '%s does not exist' % ref
426     sha1 = GRun('show-ref', '-s', ref).output_one_line()
427     try:
428         GRun('update-ref', '-d', ref, sha1).run()
429     except GitRunException:
430         raise GitException, 'Failed to delete ref %s' % ref
431
432 def delete_branch(name):
433     delete_ref('refs/heads/%s' % name)
434
435 def rename_ref(from_ref, to_ref):
436     if not ref_exists(from_ref):
437         raise GitException, '"%s" does not exist' % from_ref
438     if ref_exists(to_ref):
439         raise GitException, '"%s" already exists' % to_ref
440
441     sha1 = GRun('show-ref', '-s', from_ref).output_one_line()
442     try:
443         GRun('update-ref', to_ref, sha1, '0'*40).run()
444     except GitRunException:
445         raise GitException, 'Failed to create new ref %s' % to_ref
446     try:
447         GRun('update-ref', '-d', from_ref, sha1).run()
448     except GitRunException:
449         raise GitException, 'Failed to delete ref %s' % from_ref
450
451 def rename_branch(from_name, to_name):
452     """Rename a git branch."""
453     rename_ref('refs/heads/%s' % from_name, 'refs/heads/%s' % to_name)
454     try:
455         if get_head_file() == from_name:
456             set_head_file(to_name)
457     except DetachedHeadException:
458         pass # detached HEAD, so the renamee can't be the current branch
459     reflog_dir = os.path.join(basedir.get(), 'logs', 'refs', 'heads')
460     if os.path.exists(reflog_dir) \
461            and os.path.exists(os.path.join(reflog_dir, from_name)):
462         rename(reflog_dir, from_name, to_name)
463
464 def add(names):
465     """Add the files or recursively add the directory contents
466     """
467     # generate the file list
468     files = []
469     for i in names:
470         if not os.path.exists(i):
471             raise GitException, 'Unknown file or directory: %s' % i
472
473         if os.path.isdir(i):
474             # recursive search. We only add files
475             for root, dirs, local_files in os.walk(i):
476                 for name in [os.path.join(root, f) for f in local_files]:
477                     if os.path.isfile(name):
478                         files.append(os.path.normpath(name))
479         elif os.path.isfile(i):
480             files.append(os.path.normpath(i))
481         else:
482             raise GitException, '%s is not a file or directory' % i
483
484     if files:
485         try:
486             GRun('update-index', '--add', '--').xargs(files)
487         except GitRunException:
488             raise GitException, 'Unable to add file'
489
490 def __copy_single(source, target, target2=''):
491     """Copy file or dir named 'source' to name target+target2"""
492
493     # "source" (file or dir) must match one or more git-controlled file
494     realfiles = GRun('ls-files', source).output_lines()
495     if len(realfiles) == 0:
496         raise GitException, '"%s" matches no git-controled files' % source
497
498     if os.path.isdir(source):
499         # physically copy the files, and record them to add them in one run
500         newfiles = []
501         re_string='^'+source+'/(.*)$'
502         prefix_regexp = re.compile(re_string)
503         for f in [f.strip() for f in realfiles]:
504             m = prefix_regexp.match(f)
505             if not m:
506                 raise Exception, '"%s" does not match "%s"' % (f, re_string)
507             newname = target+target2+'/'+m.group(1)
508             if not os.path.exists(os.path.dirname(newname)):
509                 os.makedirs(os.path.dirname(newname))
510             copyfile(f, newname)
511             newfiles.append(newname)
512
513         add(newfiles)
514     else: # files, symlinks, ...
515         newname = target+target2
516         copyfile(source, newname)
517         add([newname])
518
519
520 def copy(filespecs, target):
521     if os.path.isdir(target):
522         # target is a directory: copy each entry on the command line,
523         # with the same name, into the target
524         target = target.rstrip('/')
525         
526         # first, check that none of the children of the target
527         # matching the command line aleady exist
528         for filespec in filespecs:
529             entry = target+ '/' + os.path.basename(filespec.rstrip('/'))
530             if os.path.exists(entry):
531                 raise GitException, 'Target "%s" already exists' % entry
532         
533         for filespec in filespecs:
534             filespec = filespec.rstrip('/')
535             basename = '/' + os.path.basename(filespec)
536             __copy_single(filespec, target, basename)
537
538     elif os.path.exists(target):
539         raise GitException, 'Target "%s" exists but is not a directory' % target
540     elif len(filespecs) != 1:
541         raise GitException, 'Cannot copy more than one file to non-directory'
542
543     else:
544         # at this point: len(filespecs)==1 and target does not exist
545
546         # check target directory
547         targetdir = os.path.dirname(target)
548         if targetdir != '' and not os.path.isdir(targetdir):
549             raise GitException, 'Target directory "%s" does not exist' % targetdir
550
551         __copy_single(filespecs[0].rstrip('/'), target)
552         
553
554 def rm(files, force = False):
555     """Remove a file from the repository
556     """
557     if not force:
558         for f in files:
559             if os.path.exists(f):
560                 raise GitException, '%s exists. Remove it first' %f
561         if files:
562             GRun('update-index', '--remove', '--').xargs(files)
563     else:
564         if files:
565             GRun('update-index', '--force-remove', '--').xargs(files)
566
567 # Persons caching
568 __user = None
569 __author = None
570 __committer = None
571
572 def user():
573     """Return the user information.
574     """
575     global __user
576     if not __user:
577         name=config.get('user.name')
578         email=config.get('user.email')
579         __user = Person(name, email)
580     return __user;
581
582 def author():
583     """Return the author information.
584     """
585     global __author
586     if not __author:
587         try:
588             # the environment variables take priority over config
589             try:
590                 date = os.environ['GIT_AUTHOR_DATE']
591             except KeyError:
592                 date = ''
593             __author = Person(os.environ['GIT_AUTHOR_NAME'],
594                               os.environ['GIT_AUTHOR_EMAIL'],
595                               date)
596         except KeyError:
597             __author = user()
598     return __author
599
600 def committer():
601     """Return the author information.
602     """
603     global __committer
604     if not __committer:
605         try:
606             # the environment variables take priority over config
607             try:
608                 date = os.environ['GIT_COMMITTER_DATE']
609             except KeyError:
610                 date = ''
611             __committer = Person(os.environ['GIT_COMMITTER_NAME'],
612                                  os.environ['GIT_COMMITTER_EMAIL'],
613                                  date)
614         except KeyError:
615             __committer = user()
616     return __committer
617
618 def update_cache(files = None, force = False):
619     """Update the cache information for the given files
620     """
621     cache_files = tree_status(files, verbose = False)
622
623     # everything is up-to-date
624     if len(cache_files) == 0:
625         return False
626
627     # check for unresolved conflicts
628     if not force and [x for x in cache_files
629                       if x[0] not in ['M', 'N', 'A', 'D']]:
630         raise GitException, 'Updating cache failed: unresolved conflicts'
631
632     # update the cache
633     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
634     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
635     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
636
637     GRun('update-index', '--add', '--').xargs(add_files)
638     GRun('update-index', '--force-remove', '--').xargs(rm_files)
639     GRun('update-index', '--').xargs(m_files)
640
641     return True
642
643 def commit(message, files = None, parents = None, allowempty = False,
644            cache_update = True, tree_id = None, set_head = False,
645            author_name = None, author_email = None, author_date = None,
646            committer_name = None, committer_email = None):
647     """Commit the current tree to repository
648     """
649     if not parents:
650         parents = []
651
652     # Get the tree status
653     if cache_update and parents != []:
654         changes = update_cache(files)
655         if not changes and not allowempty:
656             raise GitException, 'No changes to commit'
657
658     # get the commit message
659     if not message:
660         message = '\n'
661     elif message[-1:] != '\n':
662         message += '\n'
663
664     # write the index to repository
665     if tree_id == None:
666         tree_id = GRun('write-tree').output_one_line()
667         set_head = True
668
669     # the commit
670     env = {}
671     if author_name:
672         env['GIT_AUTHOR_NAME'] = author_name
673     if author_email:
674         env['GIT_AUTHOR_EMAIL'] = author_email
675     if author_date:
676         env['GIT_AUTHOR_DATE'] = author_date
677     if committer_name:
678         env['GIT_COMMITTER_NAME'] = committer_name
679     if committer_email:
680         env['GIT_COMMITTER_EMAIL'] = committer_email
681     commit_id = GRun('commit-tree', tree_id,
682                      *sum([['-p', p] for p in parents], [])
683                      ).env(env).raw_input(message).output_one_line()
684     if set_head:
685         __set_head(commit_id)
686
687     return commit_id
688
689 def apply_diff(rev1, rev2, check_index = True, files = None):
690     """Apply the diff between rev1 and rev2 onto the current
691     index. This function doesn't need to raise an exception since it
692     is only used for fast-pushing a patch. If this operation fails,
693     the pushing would fall back to the three-way merge.
694     """
695     if check_index:
696         index_opt = ['--index']
697     else:
698         index_opt = []
699
700     if not files:
701         files = []
702
703     diff_str = diff(files, rev1, rev2)
704     if diff_str:
705         try:
706             GRun('apply', *index_opt).raw_input(
707                 diff_str).discard_stderr().no_output()
708         except GitRunException:
709             return False
710
711     return True
712
713 def merge(base, head1, head2, recursive = False):
714     """Perform a 3-way merge between base, head1 and head2 into the
715     local tree
716     """
717     refresh_index()
718
719     err_output = None
720     if recursive:
721         # this operation tracks renames but it is slower (used in
722         # general when pushing or picking patches)
723         try:
724             # discard output to mask the verbose prints of the tool
725             GRun('merge-recursive', base, '--', head1, head2
726                  ).discard_output()
727         except GitRunException, ex:
728             err_output = str(ex)
729             pass
730     else:
731         # the fast case where we don't track renames (used when the
732         # distance between base and heads is small, i.e. folding or
733         # synchronising patches)
734         try:
735             GRun('read-tree', '-u', '-m', '--aggressive',
736                  base, head1, head2).run()
737         except GitRunException:
738             raise GitException, 'read-tree failed (local changes maybe?)'
739
740     # check the index for unmerged entries
741     files = {}
742     stages_re = re.compile('^([0-7]+) ([0-9a-f]{40}) ([1-3])\t(.*)$', re.S)
743
744     for line in GRun('ls-files', '--unmerged', '--stage', '-z'
745                      ).raw_output().split('\0'):
746         if not line:
747             continue
748
749         mode, hash, stage, path = stages_re.findall(line)[0]
750
751         if not path in files:
752             files[path] = {}
753             files[path]['1'] = ('', '')
754             files[path]['2'] = ('', '')
755             files[path]['3'] = ('', '')
756
757         files[path][stage] = (mode, hash)
758
759     if err_output and not files:
760         # if no unmerged files, there was probably a different type of
761         # error and we have to abort the merge
762         raise GitException, err_output
763
764     # merge the unmerged files
765     errors = False
766     for path in files:
767         # remove additional files that might be generated for some
768         # newer versions of GIT
769         for suffix in [base, head1, head2]:
770             if not suffix:
771                 continue
772             fname = path + '~' + suffix
773             if os.path.exists(fname):
774                 os.remove(fname)
775
776         stages = files[path]
777         if gitmergeonefile.merge(stages['1'][1], stages['2'][1],
778                                  stages['3'][1], path, stages['1'][0],
779                                  stages['2'][0], stages['3'][0]) != 0:
780             errors = True
781
782     if errors:
783         raise GitException, 'GIT index merging failed (possible conflicts)'
784
785 def diff(files = None, rev1 = 'HEAD', rev2 = None, diff_flags = [],
786          binary = True):
787     """Show the diff between rev1 and rev2
788     """
789     if not files:
790         files = []
791     if binary and '--binary' not in diff_flags:
792         diff_flags = diff_flags + ['--binary']
793
794     if rev1 and rev2:
795         return GRun('diff-tree', '-p',
796                     *(diff_flags + [rev1, rev2, '--'] + files)).raw_output()
797     elif rev1 or rev2:
798         refresh_index()
799         if rev2:
800             return GRun('diff-index', '-p', '-R',
801                         *(diff_flags + [rev2, '--'] + files)).raw_output()
802         else:
803             return GRun('diff-index', '-p',
804                         *(diff_flags + [rev1, '--'] + files)).raw_output()
805     else:
806         return ''
807
808 # TODO: take another parameter representing a diff string as we
809 # usually invoke git.diff() form the calling functions
810 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
811     """Return the diffstat between rev1 and rev2."""
812     return GRun('apply', '--stat', '--summary'
813                 ).raw_input(diff(files, rev1, rev2)).raw_output()
814
815 def files(rev1, rev2, diff_flags = []):
816     """Return the files modified between rev1 and rev2
817     """
818
819     result = []
820     for line in GRun('diff-tree', *(diff_flags + ['-r', rev1, rev2])
821                      ).output_lines():
822         result.append('%s %s' % tuple(line.split(' ', 4)[-1].split('\t', 1)))
823
824     return '\n'.join(result)
825
826 def barefiles(rev1, rev2):
827     """Return the files modified between rev1 and rev2, without status info
828     """
829
830     result = []
831     for line in GRun('diff-tree', '-r', rev1, rev2).output_lines():
832         result.append(line.split(' ', 4)[-1].split('\t', 1)[-1])
833
834     return '\n'.join(result)
835
836 def pretty_commit(commit_id = 'HEAD', flags = []):
837     """Return a given commit (log + diff)
838     """
839     return GRun('show', *(flags + [commit_id])).raw_output()
840
841 def checkout(files = None, tree_id = None, force = False):
842     """Check out the given or all files
843     """
844     if tree_id:
845         try:
846             GRun('read-tree', '--reset', tree_id).run()
847         except GitRunException:
848             raise GitException, 'Failed "git read-tree" --reset %s' % tree_id
849
850     cmd = ['checkout-index', '-q', '-u']
851     if force:
852         cmd.append('-f')
853     if files:
854         GRun(*(cmd + ['--'])).xargs(files)
855     else:
856         GRun(*(cmd + ['-a'])).run()
857
858 def switch(tree_id, keep = False):
859     """Switch the tree to the given id
860     """
861     if keep:
862         # only update the index while keeping the local changes
863         GRun('read-tree', tree_id).run()
864     else:
865         refresh_index()
866         try:
867             GRun('read-tree', '-u', '-m', get_head(), tree_id).run()
868         except GitRunException:
869             raise GitException, 'read-tree failed (local changes maybe?)'
870
871     __set_head(tree_id)
872
873 def reset(files = None, tree_id = None, check_out = True):
874     """Revert the tree changes relative to the given tree_id. It removes
875     any local changes
876     """
877     if not tree_id:
878         tree_id = get_head()
879
880     if check_out:
881         cache_files = tree_status(files, tree_id)
882         # files which were added but need to be removed
883         rm_files =  [x[1] for x in cache_files if x[0] in ['A']]
884
885         checkout(files, tree_id, True)
886         # checkout doesn't remove files
887         map(os.remove, rm_files)
888
889     # if the reset refers to the whole tree, switch the HEAD as well
890     if not files:
891         __set_head(tree_id)
892
893 def fetch(repository = 'origin', refspec = None):
894     """Fetches changes from the remote repository, using 'git fetch'
895     by default.
896     """
897     # we update the HEAD
898     __clear_head_cache()
899
900     args = [repository]
901     if refspec:
902         args.append(refspec)
903
904     command = config.get('branch.%s.stgit.fetchcmd' % get_head_file()) or \
905               config.get('stgit.fetchcmd')
906     Run(*(command.split() + args)).run()
907
908 def pull(repository = 'origin', refspec = None):
909     """Fetches changes from the remote repository, using 'git pull'
910     by default.
911     """
912     # we update the HEAD
913     __clear_head_cache()
914
915     args = [repository]
916     if refspec:
917         args.append(refspec)
918
919     command = config.get('branch.%s.stgit.pullcmd' % get_head_file()) or \
920               config.get('stgit.pullcmd')
921     Run(*(command.split() + args)).run()
922
923 def rebase(tree_id = None):
924     """Rebase the current tree to the give tree_id. The tree_id
925     argument may be something other than a GIT id if an external
926     command is invoked.
927     """
928     command = config.get('branch.%s.stgit.rebasecmd' % get_head_file()) \
929                 or config.get('stgit.rebasecmd')
930     if tree_id:
931         args = [tree_id]
932     elif command:
933         args = []
934     else:
935         raise GitException, 'Default rebasing requires a commit id'
936     if command:
937         # clear the HEAD cache as the custom rebase command will update it
938         __clear_head_cache()
939         Run(*(command.split() + args)).run()
940     else:
941         # default rebasing
942         reset(tree_id = tree_id)
943
944 def repack():
945     """Repack all objects into a single pack
946     """
947     GRun('repack', '-a', '-d', '-f').run()
948
949 def apply_patch(filename = None, diff = None, base = None,
950                 fail_dump = True):
951     """Apply a patch onto the current or given index. There must not
952     be any local changes in the tree, otherwise the command fails
953     """
954     if diff is None:
955         if filename:
956             f = file(filename)
957         else:
958             f = sys.stdin
959         diff = f.read()
960         if filename:
961             f.close()
962
963     if base:
964         orig_head = get_head()
965         switch(base)
966     else:
967         refresh_index()
968
969     try:
970         GRun('apply', '--index').raw_input(diff).no_output()
971     except GitRunException:
972         if base:
973             switch(orig_head)
974         if fail_dump:
975             # write the failed diff to a file
976             f = file('.stgit-failed.patch', 'w+')
977             f.write(diff)
978             f.close()
979             out.warn('Diff written to the .stgit-failed.patch file')
980
981         raise
982
983     if base:
984         top = commit(message = 'temporary commit used for applying a patch',
985                      parents = [base])
986         switch(orig_head)
987         merge(base, orig_head, top)
988
989 def clone(repository, local_dir):
990     """Clone a remote repository. At the moment, just use the
991     'git clone' script
992     """
993     GRun('clone', repository, local_dir).run()
994
995 def modifying_revs(files, base_rev, head_rev):
996     """Return the revisions from the list modifying the given files."""
997     return GRun('rev-list', '%s..%s' % (base_rev, head_rev), '--', *files
998                 ).output_lines()
999
1000 def refspec_localpart(refspec):
1001     m = re.match('^[^:]*:([^:]*)$', refspec)
1002     if m:
1003         return m.group(1)
1004     else:
1005         raise GitException, 'Cannot parse refspec "%s"' % line
1006
1007 def refspec_remotepart(refspec):
1008     m = re.match('^([^:]*):[^:]*$', refspec)
1009     if m:
1010         return m.group(1)
1011     else:
1012         raise GitException, 'Cannot parse refspec "%s"' % line
1013
1014 def __remotes_from_config():
1015     return config.sections_matching(r'remote\.(.*)\.url')
1016
1017 def __remotes_from_dir(dir):
1018     d = os.path.join(basedir.get(), dir)
1019     if os.path.exists(d):
1020         return os.listdir(d)
1021     else:
1022         return []
1023
1024 def remotes_list():
1025     """Return the list of remotes in the repository
1026     """
1027     return (set(__remotes_from_config())
1028             | set(__remotes_from_dir('remotes'))
1029             | set(__remotes_from_dir('branches')))
1030
1031 def remotes_local_branches(remote):
1032     """Returns the list of local branches fetched from given remote
1033     """
1034
1035     branches = []
1036     if remote in __remotes_from_config():
1037         for line in config.getall('remote.%s.fetch' % remote):
1038             branches.append(refspec_localpart(line))
1039     elif remote in __remotes_from_dir('remotes'):
1040         stream = open(os.path.join(basedir.get(), 'remotes', remote), 'r')
1041         for line in stream:
1042             # Only consider Pull lines
1043             m = re.match('^Pull: (.*)\n$', line)
1044             if m:
1045                 branches.append(refspec_localpart(m.group(1)))
1046         stream.close()
1047     elif remote in __remotes_from_dir('branches'):
1048         # old-style branches only declare one branch
1049         branches.append('refs/heads/'+remote);
1050     else:
1051         raise GitException, 'Unknown remote "%s"' % remote
1052
1053     return branches
1054
1055 def identify_remote(branchname):
1056     """Return the name for the remote to pull the given branchname
1057     from, or None if we believe it is a local branch.
1058     """
1059
1060     for remote in remotes_list():
1061         if branchname in remotes_local_branches(remote):
1062             return remote
1063
1064     # if we get here we've found nothing, the branch is a local one
1065     return None
1066
1067 def fetch_head():
1068     """Return the git id for the tip of the parent branch as left by
1069     'git fetch'.
1070     """
1071
1072     fetch_head=None
1073     stream = open(os.path.join(basedir.get(), 'FETCH_HEAD'), "r")
1074     for line in stream:
1075         # Only consider lines not tagged not-for-merge
1076         m = re.match('^([^\t]*)\t\t', line)
1077         if m:
1078             if fetch_head:
1079                 raise GitException, 'StGit does not support multiple FETCH_HEAD'
1080             else:
1081                 fetch_head=m.group(1)
1082     stream.close()
1083
1084     if not fetch_head:
1085         out.warn('No for-merge remote head found in FETCH_HEAD')
1086
1087     # here we are sure to have a single fetch_head
1088     return fetch_head
1089
1090 def all_refs():
1091     """Return a list of all refs in the current repository.
1092     """
1093
1094     return [line.split()[1] for line in GRun('show-ref').output_lines()]