chiark / gitweb /
Allow 'stg status --reset' to work on individual files
[stgit] / stgit / git.py
1 """Python GIT interface
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, popen2
22
23 from stgit.utils import *
24
25 # git exception class
26 class GitException(Exception):
27     pass
28
29
30 # Different start-up variables read from the environment
31 if 'GIT_DIR' in os.environ:
32     base_dir = os.environ['GIT_DIR']
33 else:
34     base_dir = '.git'
35
36
37 #
38 # Classes
39 #
40 class Commit:
41     """Handle the commit objects
42     """
43     def __init__(self, id_hash):
44         self.__id_hash = id_hash
45
46         lines = _output_lines('git-cat-file commit %s' % id_hash)
47         self.__parents = []
48         for i in range(len(lines)):
49             line = lines[i]
50             if line == '\n':
51                 break
52             field = line.strip().split(' ', 1)
53             if field[0] == 'tree':
54                 self.__tree = field[1]
55             elif field[0] == 'parent':
56                 self.__parents.append(field[1])
57             if field[0] == 'author':
58                 self.__author = field[1]
59             if field[0] == 'committer':
60                 self.__committer = field[1]
61         self.__log = ''.join(lines[i+1:])
62
63     def get_id_hash(self):
64         return self.__id_hash
65
66     def get_tree(self):
67         return self.__tree
68
69     def get_parent(self):
70         return self.__parents[0]
71
72     def get_parents(self):
73         return self.__parents
74
75     def get_author(self):
76         return self.__author
77
78     def get_committer(self):
79         return self.__committer
80
81     def get_log(self):
82         return self.__log
83
84 # dictionary of Commit objects, used to avoid multiple calls to git
85 __commits = dict()
86
87 #
88 # Functions
89 #
90 def get_commit(id_hash):
91     """Commit objects factory. Save/look-up them in the __commits
92     dictionary
93     """
94     global __commits
95
96     if id_hash in __commits:
97         return __commits[id_hash]
98     else:
99         commit = Commit(id_hash)
100         __commits[id_hash] = commit
101         return commit
102
103 def get_conflicts():
104     """Return the list of file conflicts
105     """
106     conflicts_file = os.path.join(base_dir, 'conflicts')
107     if os.path.isfile(conflicts_file):
108         f = file(conflicts_file)
109         names = [line.strip() for line in f.readlines()]
110         f.close()
111         return names
112     else:
113         return None
114
115 def _input(cmd, file_desc):
116     p = popen2.Popen3(cmd, True)
117     while True:
118         line = file_desc.readline()
119         if not line:
120             break
121         p.tochild.write(line)
122     p.tochild.close()
123     if p.wait():
124         raise GitException, '%s failed' % str(cmd)
125
126 def _output(cmd):
127     p=popen2.Popen3(cmd, True)
128     output = p.fromchild.read()
129     if p.wait():
130         raise GitException, '%s failed' % str(cmd)
131     return output
132
133 def _output_one_line(cmd, file_desc = None):
134     p=popen2.Popen3(cmd, True)
135     if file_desc != None:
136         for line in file_desc:
137             p.tochild.write(line)
138         p.tochild.close()
139     output = p.fromchild.readline().strip()
140     if p.wait():
141         raise GitException, '%s failed' % str(cmd)
142     return output
143
144 def _output_lines(cmd):
145     p=popen2.Popen3(cmd, True)
146     lines = p.fromchild.readlines()
147     if p.wait():
148         raise GitException, '%s failed' % str(cmd)
149     return lines
150
151 def __run(cmd, args=None):
152     """__run: runs cmd using spawnvp.
153
154     Runs cmd using spawnvp.  The shell is avoided so it won't mess up
155     our arguments.  If args is very large, the command is run multiple
156     times; args is split xargs style: cmd is passed on each
157     invocation.  Unlike xargs, returns immediately if any non-zero
158     return code is received.  
159     """
160     
161     args_l=cmd.split()
162     if args is None:
163         args = []
164     for i in range(0, len(args)+1, 100):
165         r=os.spawnvp(os.P_WAIT, args_l[0], args_l + args[i:min(i+100, len(args))])
166     if r:
167         return r
168     return 0
169
170 def __check_base_dir():
171     return os.path.isdir(base_dir)
172
173 def __tree_status(files = None, tree_id = 'HEAD', unknown = False,
174                   noexclude = True):
175     """Returns a list of pairs - [status, filename]
176     """
177     refresh_index()
178
179     if not files:
180         files = []
181     cache_files = []
182
183     # unknown files
184     if unknown:
185         exclude_file = os.path.join(base_dir, 'info', 'exclude')
186         base_exclude = ['--exclude=%s' % s for s in
187                         ['*.[ao]', '*.pyc', '.*', '*~', '#*', 'TAGS', 'tags']]
188         base_exclude.append('--exclude-per-directory=.gitignore')
189
190         if os.path.exists(exclude_file):
191             extra_exclude = ['--exclude-from=%s' % exclude_file]
192         else:
193             extra_exclude = []
194         if noexclude:
195             extra_exclude = base_exclude = []
196
197         lines = _output_lines(['git-ls-files', '--others'] + base_exclude
198                         + extra_exclude)
199         cache_files += [('?', line.strip()) for line in lines]
200
201     # conflicted files
202     conflicts = get_conflicts()
203     if not conflicts:
204         conflicts = []
205     cache_files += [('C', filename) for filename in conflicts]
206
207     # the rest
208     for line in _output_lines(['git-diff-index', '-r', tree_id] + files):
209         fs = tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
210         if fs[1] not in conflicts:
211             cache_files.append(fs)
212
213     return cache_files
214
215 def local_changes():
216     """Return true if there are local changes in the tree
217     """
218     return len(__tree_status()) != 0
219
220 # HEAD value cached
221 __head = None
222
223 def get_head():
224     """Verifies the HEAD and returns the SHA1 id that represents it
225     """
226     global __head
227
228     if not __head:
229         __head = rev_parse('HEAD')
230     return __head
231
232 def get_head_file():
233     """Returns the name of the file pointed to by the HEAD link
234     """
235     return os.path.basename(_output_one_line('git-symbolic-ref HEAD'))
236
237 def set_head_file(ref):
238     """Resets HEAD to point to a new ref
239     """
240     # head cache flushing is needed since we might have a different value
241     # in the new head
242     __clear_head_cache()
243     if __run('git-symbolic-ref HEAD', [ref]) != 0:
244         raise GitException, 'Could not set head to "%s"' % ref
245
246 def __set_head(val):
247     """Sets the HEAD value
248     """
249     global __head
250
251     if not __head or __head != val:
252         if __run('git-update-ref HEAD', [val]) != 0:
253             raise GitException, 'Could not update HEAD to "%s".' % val
254         __head = val
255
256 def __clear_head_cache():
257     """Sets the __head to None so that a re-read is forced
258     """
259     global __head
260
261     __head = None
262
263 def refresh_index():
264     """Refresh index with stat() information from the working directory.
265     """
266     __run('git-update-index -q --unmerged --refresh')
267
268 def rev_parse(git_id):
269     """Parse the string and return a verified SHA1 id
270     """
271     try:
272         return _output_one_line(['git-rev-parse', '--verify', git_id])
273     except GitException:
274         raise GitException, 'Unknown revision: %s' % git_id
275
276 def branch_exists(branch):
277     """Existance check for the named branch
278     """
279     for line in _output_lines(['git-rev-parse', '--symbolic', '--all']):
280         if line.strip() == branch:
281             return True
282     return False
283
284 def create_branch(new_branch, tree_id = None):
285     """Create a new branch in the git repository
286     """
287     new_head = os.path.join('refs', 'heads', new_branch)
288     if branch_exists(new_head):
289         raise GitException, 'Branch "%s" already exists' % new_branch
290
291     current_head = get_head()
292     set_head_file(new_head)
293     __set_head(current_head)
294
295     # a checkout isn't needed if new branch points to the current head
296     if tree_id:
297         switch(tree_id)
298
299     if os.path.isfile(os.path.join(base_dir, 'MERGE_HEAD')):
300         os.remove(os.path.join(base_dir, 'MERGE_HEAD'))
301
302 def switch_branch(name):
303     """Switch to a git branch
304     """
305     global __head
306
307     new_head = os.path.join('refs', 'heads', name)
308     if not branch_exists(new_head):
309         raise GitException, 'Branch "%s" does not exist' % name
310
311     tree_id = rev_parse(new_head + '^0')
312     if tree_id != get_head():
313         refresh_index()
314         if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
315             raise GitException, 'git-read-tree failed (local changes maybe?)'
316         __head = tree_id
317     set_head_file(new_head)
318
319     if os.path.isfile(os.path.join(base_dir, 'MERGE_HEAD')):
320         os.remove(os.path.join(base_dir, 'MERGE_HEAD'))
321
322 def delete_branch(name):
323     """Delete a git branch
324     """
325     branch_head = os.path.join('refs', 'heads', name)
326     if not branch_exists(branch_head):
327         raise GitException, 'Branch "%s" does not exist' % name
328     os.remove(os.path.join(base_dir, branch_head))
329
330 def rename_branch(from_name, to_name):
331     """Rename a git branch
332     """
333     from_head = os.path.join('refs', 'heads', from_name)
334     if not branch_exists(from_head):
335         raise GitException, 'Branch "%s" does not exist' % from_name
336     to_head = os.path.join('refs', 'heads', to_name)
337     if branch_exists(to_head):
338         raise GitException, 'Branch "%s" already exists' % to_name
339
340     if get_head_file() == from_name:
341         set_head_file(to_head)
342     os.rename(os.path.join(base_dir, from_head), os.path.join(base_dir, to_head))
343
344 def add(names):
345     """Add the files or recursively add the directory contents
346     """
347     # generate the file list
348     files = []
349     for i in names:
350         if not os.path.exists(i):
351             raise GitException, 'Unknown file or directory: %s' % i
352
353         if os.path.isdir(i):
354             # recursive search. We only add files
355             for root, dirs, local_files in os.walk(i):
356                 for name in [os.path.join(root, f) for f in local_files]:
357                     if os.path.isfile(name):
358                         files.append(os.path.normpath(name))
359         elif os.path.isfile(i):
360             files.append(os.path.normpath(i))
361         else:
362             raise GitException, '%s is not a file or directory' % i
363
364     if files:
365         if __run('git-update-index --add --', files):
366             raise GitException, 'Unable to add file'
367
368 def rm(files, force = False):
369     """Remove a file from the repository
370     """
371     if not force:
372         for f in files:
373             if os.path.exists(f):
374                 raise GitException, '%s exists. Remove it first' %f
375         if files:
376             __run('git-update-index --remove --', files)
377     else:
378         if files:
379             __run('git-update-index --force-remove --', files)
380
381 def update_cache(files = None, force = False):
382     """Update the cache information for the given files
383     """
384     if not files:
385         files = []
386
387     cache_files = __tree_status(files)
388
389     # everything is up-to-date
390     if len(cache_files) == 0:
391         return False
392
393     # check for unresolved conflicts
394     if not force and [x for x in cache_files
395                       if x[0] not in ['M', 'N', 'A', 'D']]:
396         raise GitException, 'Updating cache failed: unresolved conflicts'
397
398     # update the cache
399     add_files = [x[1] for x in cache_files if x[0] in ['N', 'A']]
400     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
401     m_files =   [x[1] for x in cache_files if x[0] in ['M']]
402
403     if add_files and __run('git-update-index --add --', add_files) != 0:
404         raise GitException, 'Failed git-update-index --add'
405     if rm_files and __run('git-update-index --force-remove --', rm_files) != 0:
406         raise GitException, 'Failed git-update-index --rm'
407     if m_files and __run('git-update-index --', m_files) != 0:
408         raise GitException, 'Failed git-update-index'
409
410     return True
411
412 def commit(message, files = None, parents = None, allowempty = False,
413            cache_update = True, tree_id = None,
414            author_name = None, author_email = None, author_date = None,
415            committer_name = None, committer_email = None):
416     """Commit the current tree to repository
417     """
418     if not files:
419         files = []
420     if not parents:
421         parents = []
422
423     # Get the tree status
424     if cache_update and parents != []:
425         changes = update_cache(files)
426         if not changes and not allowempty:
427             raise GitException, 'No changes to commit'
428
429     # get the commit message
430     if message[-1:] != '\n':
431         message += '\n'
432
433     must_switch = True
434     # write the index to repository
435     if tree_id == None:
436         tree_id = _output_one_line('git-write-tree')
437     else:
438         must_switch = False
439
440     # the commit
441     cmd = ''
442     if author_name:
443         cmd += 'GIT_AUTHOR_NAME="%s" ' % author_name
444     if author_email:
445         cmd += 'GIT_AUTHOR_EMAIL="%s" ' % author_email
446     if author_date:
447         cmd += 'GIT_AUTHOR_DATE="%s" ' % author_date
448     if committer_name:
449         cmd += 'GIT_COMMITTER_NAME="%s" ' % committer_name
450     if committer_email:
451         cmd += 'GIT_COMMITTER_EMAIL="%s" ' % committer_email
452     cmd += 'git-commit-tree %s' % tree_id
453
454     # get the parents
455     for p in parents:
456         cmd += ' -p %s' % p
457
458     commit_id = _output_one_line(cmd, message)
459     if must_switch:
460         __set_head(commit_id)
461
462     return commit_id
463
464 def apply_diff(rev1, rev2):
465     """Apply the diff between rev1 and rev2 onto the current
466     index. This function doesn't need to raise an exception since it
467     is only used for fast-pushing a patch. If this operation fails,
468     the pushing would fall back to the three-way merge.
469     """
470     return os.system('git-diff-tree -p %s %s | git-apply --index 2> /dev/null'
471                      % (rev1, rev2)) == 0
472
473 def merge(base, head1, head2):
474     """Perform a 3-way merge between base, head1 and head2 into the
475     local tree
476     """
477     refresh_index()
478     if __run('git-read-tree -u -m', [base, head1, head2]) != 0:
479         raise GitException, 'git-read-tree failed (local changes maybe?)'
480
481     # this can fail if there are conflicts
482     if __run('git-merge-index -o -q gitmergeonefile.py -a') != 0:
483         raise GitException, 'git-merge-index failed (possible conflicts)'
484
485 def status(files = None, modified = False, new = False, deleted = False,
486            conflict = False, unknown = False, noexclude = False):
487     """Show the tree status
488     """
489     if not files:
490         files = []
491
492     cache_files = __tree_status(files, unknown = True, noexclude = noexclude)
493     all = not (modified or new or deleted or conflict or unknown)
494
495     if not all:
496         filestat = []
497         if modified:
498             filestat.append('M')
499         if new:
500             filestat.append('A')
501             filestat.append('N')
502         if deleted:
503             filestat.append('D')
504         if conflict:
505             filestat.append('C')
506         if unknown:
507             filestat.append('?')
508         cache_files = [x for x in cache_files if x[0] in filestat]
509
510     for fs in cache_files:
511         if all:
512             print '%s %s' % (fs[0], fs[1])
513         else:
514             print '%s' % fs[1]
515
516 def diff(files = None, rev1 = 'HEAD', rev2 = None, out_fd = None):
517     """Show the diff between rev1 and rev2
518     """
519     if not files:
520         files = []
521
522     if rev1 and rev2:
523         diff_str = _output(['git-diff-tree', '-p', rev1, rev2] + files)
524     elif rev1 or rev2:
525         refresh_index()
526         if rev2:
527             diff_str = _output(['git-diff-index', '-p', '-R', rev2] + files)
528         else:
529             diff_str = _output(['git-diff-index', '-p', rev1] + files)
530     else:
531         diff_str = ''
532
533     if out_fd:
534         out_fd.write(diff_str)
535     else:
536         return diff_str
537
538 def diffstat(files = None, rev1 = 'HEAD', rev2 = None):
539     """Return the diffstat between rev1 and rev2
540     """
541     if not files:
542         files = []
543
544     p=popen2.Popen3('git-apply --stat')
545     diff(files, rev1, rev2, p.tochild)
546     p.tochild.close()
547     diff_str = p.fromchild.read().rstrip()
548     if p.wait():
549         raise GitException, 'git.diffstat failed'
550     return diff_str
551
552 def files(rev1, rev2):
553     """Return the files modified between rev1 and rev2
554     """
555
556     result = ''
557     for line in _output_lines('git-diff-tree -r %s %s' % (rev1, rev2)):
558         result += '%s %s\n' % tuple(line.rstrip().split(' ',4)[-1].split('\t',1))
559
560     return result.rstrip()
561
562 def barefiles(rev1, rev2):
563     """Return the files modified between rev1 and rev2, without status info
564     """
565
566     result = ''
567     for line in _output_lines('git-diff-tree -r %s %s' % (rev1, rev2)):
568         result += '%s\n' % line.rstrip().split(' ',4)[-1].split('\t',1)[-1]
569
570     return result.rstrip()
571
572 def checkout(files = None, tree_id = None, force = False):
573     """Check out the given or all files
574     """
575     if not files:
576         files = []
577
578     if tree_id and __run('git-read-tree -m', [tree_id]) != 0:
579         raise GitException, 'Failed git-read-tree -m %s' % tree_id
580
581     checkout_cmd = 'git-checkout-index -q -u'
582     if force:
583         checkout_cmd += ' -f'
584     if len(files) == 0:
585         checkout_cmd += ' -a'
586     else:
587         checkout_cmd += ' --'
588
589     if __run(checkout_cmd, files) != 0:
590         raise GitException, 'Failed git-checkout-index'
591
592 def switch(tree_id):
593     """Switch the tree to the given id
594     """
595     refresh_index()
596     if __run('git-read-tree -u -m', [get_head(), tree_id]) != 0:
597         raise GitException, 'git-read-tree failed (local changes maybe?)'
598
599     __set_head(tree_id)
600
601 def reset(files = None, tree_id = None):
602     """Revert the tree changes relative to the given tree_id. It removes
603     any local changes
604     """
605     if not tree_id:
606         tree_id = get_head()
607
608     cache_files = __tree_status(files, tree_id)
609     rm_files =  [x[1] for x in cache_files if x[0] in ['D']]
610
611     checkout(files, tree_id, True)
612     # checkout doesn't remove files
613     map(os.remove, rm_files)
614
615     # if the reset refers to the whole tree, switch the HEAD as well
616     if not files:
617         __set_head(tree_id)
618
619 def pull(repository = 'origin', refspec = None):
620     """Pull changes from the remote repository. At the moment, just
621     use the 'git pull' command
622     """
623     # 'git pull' updates the HEAD
624     __clear_head_cache()
625
626     args = [repository]
627     if refspec:
628         args.append(refspec)
629
630     if __run('git pull', args) != 0:
631         raise GitException, 'Failed "git pull %s"' % repository
632
633 def apply_patch(filename = None, base = None):
634     """Apply a patch onto the current or given index. There must not
635     be any local changes in the tree, otherwise the command fails
636     """
637     def __apply_patch():
638         if filename:
639             return __run('git-apply --index', [filename]) == 0
640         else:
641             try:
642                 _input('git-apply --index', sys.stdin)
643             except GitException:
644                 return False
645             return True
646
647     if base:
648         orig_head = get_head()
649         switch(base)
650     else:
651         refresh_index()         # needed since __apply_patch() doesn't do it
652
653     if not __apply_patch():
654         if base:
655             switch(orig_head)
656         raise GitException, 'Patch does not apply cleanly'
657     elif base:
658         top = commit(message = 'temporary commit used for applying a patch',
659                      parents = [base])
660         switch(orig_head)
661         merge(base, orig_head, top)
662
663 def clone(repository, local_dir):
664     """Clone a remote repository. At the moment, just use the
665     'git clone' script
666     """
667     if __run('git clone', [repository, local_dir]) != 0:
668         raise GitException, 'Failed "git clone %s %s"' \
669               % (repository, local_dir)
670
671 def modifying_revs(files, base_rev):
672     """Return the revisions from the list modifying the given files
673     """
674     cmd = ['git-rev-list', '%s..' % base_rev, '--']
675     revs = [line.strip() for line in _output_lines(cmd + files)]
676
677     return revs