chiark / gitweb /
74c2c108f3519f160d9d8d64d52e500f4284abdd
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import git, basedir, templates
29 from stgit.config import config
30 from shutil import copyfile
31 from stgit.lib import git as libgit, stackupgrade
32
33 # stack exception class
class StackException(StgException):
    """Raised when an operation on a stack (series or patch) fails."""
36
class FilterUntil:
    """Stateful line filter used to strip annotations from a message.

    Each call gets a line ``x``, a predicate ``until_test`` and a
    ``prefix``.  As soon as ``until_test`` accepts a line, that line and
    every later one are rejected.  Before that point a line is accepted
    exactly when it does not start with ``prefix``.
    """
    def __init__(self):
        # Becomes False for good once the stop line has been seen.
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        return self.should_print and not x.startswith(prefix)
46
47 #
48 # Functions
49 #
# __comment_prefix marks lines dropped from an edited message.
# __patch_prefix, alone on a line, ends the message: that line and
# everything after it are dropped as well (see __clean_comments).
__comment_prefix, __patch_prefix = 'STG:', 'STG_PATCH:'
52
def __clean_comments(f):
    """Strip stgit annotations from the open message file f, in place.

    Lines starting with the comment prefix are dropped.  The first line
    consisting exactly of the patch marker, and everything after it, is
    dropped.  Trailing empty lines are removed.  The file is then rewritten
    from the start with the remaining lines.
    """
    f.seek(0)
    marker = __patch_prefix + '\n'

    kept = []
    for text in f.readlines():
        if text == marker:
            # everything from the marker on (e.g. an appended diff) goes
            break
        if not text.startswith(__comment_prefix):
            kept.append(text)

    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
71
# TODO: move this out of the stgit.stack module, it is really for
# higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch message in an external editor.

    A temporary '.stgitmsg.txt' is filled with the following, in order:
    - the initial text: line, or else the 'patchdescr.tmpl' template, or
      else an empty line;
    - comment-prefixed lines explaining the editing rules;
    - when show_patch is true, the patch marker followed by the diff
      against the bottom of the topmost patch (or against the stack base
      when no patch is applied).

    After the editor returns, the annotations are stripped, the temporary
    file is removed and the remaining text is returned.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    with open(fname, 'w+') as f:
        if line:
            f.write('%s\n' % line)
        elif tmpl:
            f.write(tmpl)
            # Without a final newline, the first comment line below would
            # be glued onto the template's last line and survive cleanup.
            if not tmpl.endswith('\n'):
                f.write('\n')
        else:
            f.write('\n')
        f.write('%s %s\n' % (__comment_prefix, comment))
        f.write('%s %s\n' % (__comment_prefix,
                'Lines prefixed with "%s" will be automatically removed.'
                % __comment_prefix))
        f.write('%s %s\n' % (__comment_prefix,
                'Trailing empty lines will be automatically removed.'))

        if show_patch:
            f.write(__patch_prefix + '\n')
            # get_patch(None) would fail when nothing is applied, so fall
            # back to the stack base in that case.
            crt = series.get_current_patch()
            if crt:
                rev1 = crt.get_bottom()
            else:
                rev1 = series.get_base()
            f.write(git.diff(rev1 = rev1))

        # Vim modeline must be near the end.
        f.write('%s %s\n' % (__comment_prefix,
                'vi: set textwidth=75 filetype=diff nobackup:'))

    call_editor(fname)

    with open(fname, 'r+') as f:
        __clean_comments(f)
        f.seek(0)
        result = f.read()

    os.remove(fname)
    return result
114
115 #
116 # Classes
117 #
118
class StgitObject:
    """Base for objects whose properties are stored one-per-file in a
    directory (a "field" is a file named after the property).
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the content of field name, or None when the field file
        is missing or empty."""
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        value = read_string(path, multiline)
        return None if value == '' else value

    def _set_field(self, name, value, multiline = False):
        """Store value in field name; a false value (None, '') removes the
        field file instead, if present."""
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
147
148
class Patch(StgitObject):
    """Basic patch implementation.

    The patch metadata (description, author/committer fields, backups of
    the previous boundaries, ...) is stored as fields in the directory
    <series_dir>/<name>.  The patch commit is referenced by the git ref
    <refs_base>/<name>, and its log by <refs_base>/<name>.log.
    """
    def __init_refs(self):
        # Derive the ref names from the current patch name.
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the (empty) metadata directory of the patch."""
        os.mkdir(self._dir())

    def delete(self, keep_log = False):
        """Remove the metadata directory and the top ref, and the log ref
        too unless keep_log is set."""
        if os.path.isdir(self._dir()):
            for f in os.listdir(self._dir()):
                os.remove(os.path.join(self._dir(), f))
            os.rmdir(self._dir())
        else:
            out.warn('Patch directory "%s" does not exist' % self._dir())
        try:
            # the reference might not exist if the repository was corrupted
            git.delete_ref(self.__top_ref)
        except git.GitException as e:
            out.warn(str(e))
        if not keep_log and git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the top ref, the log ref (if any) and the metadata
        directory to newname."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        # Point the top ref at ref and mirror top/bottom into the fields.
        git.set_ref(self.__top_ref, ref)
        self._set_field('top', ref)
        self._set_field('bottom', git.get_commit(ref).get_parent())

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def get_old_bottom(self):
        """Return the parent of the backed-up top ('top.old'), or None when
        no backup has been recorded."""
        old_top = self.get_old_top()
        if not old_top:
            return None
        return git.get_commit(old_top).get_parent()

    def get_bottom(self):
        return git.get_commit(self.get_top()).get_parent()

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return git.rev_parse(self.__top_ref)

    def set_top(self, value, backup = False):
        """Move the patch top to commit value.

        With backup, the current top and bottom are first saved as
        'top.old' and 'bottom.old'.
        """
        if backup:
            curr_top = self.get_top()
            self._set_field('top.old', curr_top)
            self._set_field('bottom.old', git.get_commit(curr_top).get_parent())
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Move the top back to the backed-up 'top.old' commit.

        Returns True if a backup existed, False otherwise.
        """
        top = self._get_field('top.old')

        if top:
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or a false value if none is stored.

        A date stored in git's raw '<seconds> <+-hhmm>' form is converted
        to an RFC 2822 date expressed in that same time zone.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        m = re.match(r'([0-9]+)\s+[+-]([0-9]+)', date)
        if m:
            # Unix time (seconds) + time zone
            secs = int(m.group(1))
            tz = date.split()[1]
            hours, minutes = divmod(int(m.group(2)), 100)
            offset = (hours * 60 + minutes) * 60
            if tz.startswith('-'):
                offset = -offset
            # formatdate() renders the UTC wall-clock time with a '-0000'
            # suffix.  Shift the time by the zone offset before swapping in
            # the zone, otherwise the string names a different instant.
            date = formatdate(secs + offset)[:-5] + tz

        return date

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Record value as the patch log, both as a field and as the log
        ref.  backup is accepted for interface compatibility but unused."""
        self._set_field('log', value)
        self.__update_log_ref(value)
284
class PatchSet(StgitObject):
    """A named set of patches whose metadata lives under
    <basedir>/patches/<name>.  The name defaults to the current branch.
    """
    def __init__(self, name = None):
        try:
            self.set_name(name or git.get_head_file())
            self.__base_dir = basedir.get()
        except git.GitException as ex:
            raise StackException('GIT tree not initialised: %s' % ex)

        self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))

    def get_name(self):
        return self.__name

    def set_name(self, name):
        self.__name = name

    def _basedir(self):
        return self.__base_dir

    def get_head(self):
        """Return the head of the branch: the top of the topmost applied
        patch, or the stack base when no patch is applied."""
        crt = self.get_current_patch()
        return crt.get_top() if crt else self.get_base()

    def __protect_file(self):
        # Marker file whose presence flags the series as protected.
        return os.path.join(self._dir(), 'protected')

    def get_protected(self):
        return os.path.isfile(self.__protect_file())

    def protect(self):
        marker = self.__protect_file()
        if not os.path.isfile(marker):
            create_empty_file(marker)

    def unprotect(self):
        marker = self.__protect_file()
        if os.path.isfile(marker):
            os.remove(marker)

    def __branch_descr(self):
        # git config key holding the branch description
        return 'branch.%s.description' % self.get_name()

    def get_description(self):
        return config.get(self.__branch_descr()) or ''

    def set_description(self, line):
        key = self.__branch_descr()
        if line:
            config.set(key, line)
        else:
            config.unset(key)

    def head_top_equal(self):
        """Return true if HEAD equals the top of the topmost patch.

        Also returns true when no patch is applied.
        """
        crt = self.get_current_patch()
        # with no patches applied there is nothing to compare
        return not crt or git.get_head() == crt.get_top()

    def is_initialised(self):
        """Check whether a stack format version is recorded for the series."""
        key = stackupgrade.format_version_key(self.get_name())
        return config.get(key) is not None
354
355
def shortlog(patches):
    """Return 'git shortlog' output summarising the commits of patches.

    Each patch contributes the 'git log --pretty=short' output of the
    range bottom..top.
    """
    chunks = []
    for p in patches:
        chunks.append(Run('git', 'log', '--pretty=short',
                          p.get_top(), '^%s' % p.get_bottom()).raw_output())
    combined = ''.join(chunks)
    return Run('git', 'shortlog').raw_input(combined).raw_output()
361
362 class Series(PatchSet):
363     """Class including the operations on series
364     """
365     def __init__(self, name = None):
366         """Takes a series name as the parameter.
367         """
368         PatchSet.__init__(self, name)
369
370         # Update the branch to the latest format version if it is
371         # initialized, but don't touch it if it isn't.
372         stackupgrade.update_to_current_format_version(
373             libgit.Repository.default(), self.get_name())
374
375         self.__refs_base = 'refs/patches/%s' % self.get_name()
376
377         self.__applied_file = os.path.join(self._dir(), 'applied')
378         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
379         self.__hidden_file = os.path.join(self._dir(), 'hidden')
380
381         # where this series keeps its patches
382         self.__patch_dir = os.path.join(self._dir(), 'patches')
383
384         # trash directory
385         self.__trash_dir = os.path.join(self._dir(), 'trash')
386
387     def __patch_name_valid(self, name):
388         """Raise an exception if the patch name is not valid.
389         """
390         if not name or re.search('[^\w.-]', name):
391             raise StackException, 'Invalid patch name: "%s"' % name
392
393     def get_patch(self, name):
394         """Return a Patch object for the given name
395         """
396         return Patch(name, self.__patch_dir, self.__refs_base)
397
398     def get_current_patch(self):
399         """Return a Patch object representing the topmost patch, or
400         None if there is no such patch."""
401         crt = self.get_current()
402         if not crt:
403             return None
404         return self.get_patch(crt)
405
406     def get_current(self):
407         """Return the name of the topmost patch, or None if there is
408         no such patch."""
409         try:
410             applied = self.get_applied()
411         except StackException:
412             # No "applied" file: branch is not initialized.
413             return None
414         try:
415             return applied[-1]
416         except IndexError:
417             # No patches applied.
418             return None
419
420     def get_applied(self):
421         if not os.path.isfile(self.__applied_file):
422             raise StackException, 'Branch "%s" not initialised' % self.get_name()
423         return read_strings(self.__applied_file)
424
425     def set_applied(self, applied):
426         write_strings(self.__applied_file, applied)
427
428     def get_unapplied(self):
429         if not os.path.isfile(self.__unapplied_file):
430             raise StackException, 'Branch "%s" not initialised' % self.get_name()
431         return read_strings(self.__unapplied_file)
432
433     def set_unapplied(self, unapplied):
434         write_strings(self.__unapplied_file, unapplied)
435
436     def get_hidden(self):
437         if not os.path.isfile(self.__hidden_file):
438             return []
439         return read_strings(self.__hidden_file)
440
441     def get_base(self):
442         # Return the parent of the bottommost patch, if there is one.
443         if os.path.isfile(self.__applied_file):
444             bottommost = file(self.__applied_file).readline().strip()
445             if bottommost:
446                 return self.get_patch(bottommost).get_bottom()
447         # No bottommost patch, so just return HEAD
448         return git.get_head()
449
450     def get_parent_remote(self):
451         value = config.get('branch.%s.remote' % self.get_name())
452         if value:
453             return value
454         elif 'origin' in git.remotes_list():
455             out.note(('No parent remote declared for stack "%s",'
456                       ' defaulting to "origin".' % self.get_name()),
457                      ('Consider setting "branch.%s.remote" and'
458                       ' "branch.%s.merge" with "git config".'
459                       % (self.get_name(), self.get_name())))
460             return 'origin'
461         else:
462             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
463
464     def __set_parent_remote(self, remote):
465         value = config.set('branch.%s.remote' % self.get_name(), remote)
466
467     def get_parent_branch(self):
468         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
469         if value:
470             return value
471         elif git.rev_parse('heads/origin'):
472             out.note(('No parent branch declared for stack "%s",'
473                       ' defaulting to "heads/origin".' % self.get_name()),
474                      ('Consider setting "branch.%s.stgit.parentbranch"'
475                       ' with "git config".' % self.get_name()))
476             return 'heads/origin'
477         else:
478             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
479
480     def __set_parent_branch(self, name):
481         if config.get('branch.%s.remote' % self.get_name()):
482             # Never set merge if remote is not set to avoid
483             # possibly-erroneous lookups into 'origin'
484             config.set('branch.%s.merge' % self.get_name(), name)
485         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
486
487     def set_parent(self, remote, localbranch):
488         if localbranch:
489             if remote:
490                 self.__set_parent_remote(remote)
491             self.__set_parent_branch(localbranch)
492         # We'll enforce this later
493 #         else:
494 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
495
496     def __patch_is_current(self, patch):
497         return patch.get_name() == self.get_current()
498
499     def patch_applied(self, name):
500         """Return true if the patch exists in the applied list
501         """
502         return name in self.get_applied()
503
504     def patch_unapplied(self, name):
505         """Return true if the patch exists in the unapplied list
506         """
507         return name in self.get_unapplied()
508
509     def patch_hidden(self, name):
510         """Return true if the patch is hidden.
511         """
512         return name in self.get_hidden()
513
514     def patch_exists(self, name):
515         """Return true if there is a patch with the given name, false
516         otherwise."""
517         return self.patch_applied(name) or self.patch_unapplied(name) \
518                or self.patch_hidden(name)
519
520     def init(self, create_at=False, parent_remote=None, parent_branch=None):
521         """Initialises the stgit series
522         """
523         if self.is_initialised():
524             raise StackException, '%s already initialized' % self.get_name()
525         for d in [self._dir()]:
526             if os.path.exists(d):
527                 raise StackException, '%s already exists' % d
528
529         if (create_at!=False):
530             git.create_branch(self.get_name(), create_at)
531
532         os.makedirs(self.__patch_dir)
533
534         self.set_parent(parent_remote, parent_branch)
535
536         self.create_empty_field('applied')
537         self.create_empty_field('unapplied')
538
539         config.set(stackupgrade.format_version_key(self.get_name()),
540                    str(stackupgrade.FORMAT_VERSION))
541
542     def rename(self, to_name):
543         """Renames a series
544         """
545         to_stack = Series(to_name)
546
547         if to_stack.is_initialised():
548             raise StackException, '"%s" already exists' % to_stack.get_name()
549
550         patches = self.get_applied() + self.get_unapplied()
551
552         git.rename_branch(self.get_name(), to_name)
553
554         for patch in patches:
555             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
556                            'refs/patches/%s/%s' % (to_name, patch))
557             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
558                            'refs/patches/%s/%s.log' % (to_name, patch))
559         if os.path.isdir(self._dir()):
560             rename(os.path.join(self._basedir(), 'patches'),
561                    self.get_name(), to_stack.get_name())
562
563         # Rename the config section
564         for k in ['branch.%s', 'branch.%s.stgit']:
565             config.rename_section(k % self.get_name(), k % to_name)
566
567         self.__init__(to_name)
568
569     def clone(self, target_series):
570         """Clones a series
571         """
572         try:
573             # allow cloning of branches not under StGIT control
574             base = self.get_base()
575         except:
576             base = git.get_head()
577         Series(target_series).init(create_at = base)
578         new_series = Series(target_series)
579
580         # generate an artificial description file
581         new_series.set_description('clone of "%s"' % self.get_name())
582
583         # clone self's entire series as unapplied patches
584         try:
585             # allow cloning of branches not under StGIT control
586             applied = self.get_applied()
587             unapplied = self.get_unapplied()
588             patches = applied + unapplied
589             patches.reverse()
590         except:
591             patches = applied = unapplied = []
592         for p in patches:
593             patch = self.get_patch(p)
594             newpatch = new_series.new_patch(p, message = patch.get_description(),
595                                             can_edit = False, unapplied = True,
596                                             bottom = patch.get_bottom(),
597                                             top = patch.get_top(),
598                                             author_name = patch.get_authname(),
599                                             author_email = patch.get_authemail(),
600                                             author_date = patch.get_authdate())
601             if patch.get_log():
602                 out.info('Setting log to %s' %  patch.get_log())
603                 newpatch.set_log(patch.get_log())
604             else:
605                 out.info('No log for %s' % p)
606
607         # fast forward the cloned series to self's top
608         new_series.forward_patches(applied)
609
610         # Clone parent informations
611         value = config.get('branch.%s.remote' % self.get_name())
612         if value:
613             config.set('branch.%s.remote' % target_series, value)
614
615         value = config.get('branch.%s.merge' % self.get_name())
616         if value:
617             config.set('branch.%s.merge' % target_series, value)
618
619         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
620         if value:
621             config.set('branch.%s.stgit.parentbranch' % target_series, value)
622
623     def delete(self, force = False):
624         """Deletes an stgit series
625         """
626         if self.is_initialised():
627             patches = self.get_unapplied() + self.get_applied()
628             if not force and patches:
629                 raise StackException, \
630                       'Cannot delete: the series still contains patches'
631             for p in patches:
632                 self.get_patch(p).delete()
633
634             # remove the trash directory if any
635             if os.path.exists(self.__trash_dir):
636                 for fname in os.listdir(self.__trash_dir):
637                     os.remove(os.path.join(self.__trash_dir, fname))
638                 os.rmdir(self.__trash_dir)
639
640             # FIXME: find a way to get rid of those manual removals
641             # (move functionality to StgitObject ?)
642             if os.path.exists(self.__applied_file):
643                 os.remove(self.__applied_file)
644             if os.path.exists(self.__unapplied_file):
645                 os.remove(self.__unapplied_file)
646             if os.path.exists(self.__hidden_file):
647                 os.remove(self.__hidden_file)
648             if os.path.exists(self._dir()+'/orig-base'):
649                 os.remove(self._dir()+'/orig-base')
650
651             if not os.listdir(self.__patch_dir):
652                 os.rmdir(self.__patch_dir)
653             else:
654                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
655
656             try:
657                 os.removedirs(self._dir())
658             except OSError:
659                 raise StackException('Series directory %s is not empty'
660                                      % self._dir())
661
662         try:
663             git.delete_branch(self.get_name())
664         except git.GitException:
665             out.warn('Could not delete branch "%s"' % self.get_name())
666
667         config.remove_section('branch.%s' % self.get_name())
668         config.remove_section('branch.%s.stgit' % self.get_name())
669
670     def refresh_patch(self, files = None, message = None, edit = False,
671                       empty = False,
672                       show_patch = False,
673                       cache_update = True,
674                       author_name = None, author_email = None,
675                       author_date = None,
676                       committer_name = None, committer_email = None,
677                       backup = True, sign_str = None, log = 'refresh',
678                       notes = None, bottom = None):
679         """Generates a new commit for the topmost patch
680         """
681         patch = self.get_current_patch()
682         if not patch:
683             raise StackException, 'No patches applied'
684
685         descr = patch.get_description()
686         if not (message or descr):
687             edit = True
688             descr = ''
689         elif message:
690             descr = message
691
692         # TODO: move this out of the stgit.stack module, it is really
693         # for higher level commands to handle the user interaction
694         if not message and edit:
695             descr = edit_file(self, descr.rstrip(), \
696                               'Please edit the description for patch "%s" ' \
697                               'above.' % patch.get_name(), show_patch)
698
699         if not author_name:
700             author_name = patch.get_authname()
701         if not author_email:
702             author_email = patch.get_authemail()
703         if not committer_name:
704             committer_name = patch.get_commname()
705         if not committer_email:
706             committer_email = patch.get_commemail()
707
708         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
709
710         if not bottom:
711             bottom = patch.get_bottom()
712
713         if empty:
714             tree_id = git.get_commit(bottom).get_tree()
715         else:
716             tree_id = None
717
718         commit_id = git.commit(files = files,
719                                message = descr, parents = [bottom],
720                                cache_update = cache_update,
721                                tree_id = tree_id,
722                                set_head = True,
723                                allowempty = True,
724                                author_name = author_name,
725                                author_email = author_email,
726                                author_date = author_date,
727                                committer_name = committer_name,
728                                committer_email = committer_email)
729
730         patch.set_top(commit_id, backup = backup)
731         patch.set_description(descr)
732         patch.set_authname(author_name)
733         patch.set_authemail(author_email)
734         patch.set_authdate(author_date)
735         patch.set_commname(committer_name)
736         patch.set_commemail(committer_email)
737
738         if log:
739             self.log_patch(patch, log, notes)
740
741         return commit_id
742
743     def undo_refresh(self):
744         """Undo the patch boundaries changes caused by 'refresh'
745         """
746         name = self.get_current()
747         assert(name)
748
749         patch = self.get_patch(name)
750         old_bottom = patch.get_old_bottom()
751         old_top = patch.get_old_top()
752
753         # the bottom of the patch is not changed by refresh. If the
754         # old_bottom is different, there wasn't any previous 'refresh'
755         # command (probably only a 'push')
756         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
757             raise StackException, 'No undo information available'
758
759         git.reset(tree_id = old_top, check_out = False)
760         if patch.restore_old_boundaries():
761             self.log_patch(patch, 'undo')
762
763     def new_patch(self, name, message = None, can_edit = True,
764                   unapplied = False, show_patch = False,
765                   top = None, bottom = None, commit = True,
766                   author_name = None, author_email = None, author_date = None,
767                   committer_name = None, committer_email = None,
768                   before_existing = False, sign_str = None):
769         """Creates a new patch, either pointing to an existing commit object,
770         or by creating a new commit object.
771         """
772
773         assert commit or (top and bottom)
774         assert not before_existing or (top and bottom)
775         assert not (commit and before_existing)
776         assert (top and bottom) or (not top and not bottom)
777         assert commit or (not top or (bottom == git.get_commit(top).get_parent()))
778
779         if name != None:
780             self.__patch_name_valid(name)
781             if self.patch_exists(name):
782                 raise StackException, 'Patch "%s" already exists' % name
783
784         # TODO: move this out of the stgit.stack module, it is really
785         # for higher level commands to handle the user interaction
786         def sign(msg):
787             return add_sign_line(msg, sign_str,
788                                  committer_name or git.committer().name,
789                                  committer_email or git.committer().email)
790         if not message and can_edit:
791             descr = edit_file(
792                 self, sign(''),
793                 'Please enter the description for the patch above.',
794                 show_patch)
795         else:
796             descr = sign(message)
797
798         head = git.get_head()
799
800         if name == None:
801             name = make_patch_name(descr, self.patch_exists)
802
803         patch = self.get_patch(name)
804         patch.create()
805
806         patch.set_description(descr)
807         patch.set_authname(author_name)
808         patch.set_authemail(author_email)
809         patch.set_authdate(author_date)
810         patch.set_commname(committer_name)
811         patch.set_commemail(committer_email)
812
813         if before_existing:
814             insert_string(self.__applied_file, patch.get_name())
815         elif unapplied:
816             patches = [patch.get_name()] + self.get_unapplied()
817             write_strings(self.__unapplied_file, patches)
818             set_head = False
819         else:
820             append_string(self.__applied_file, patch.get_name())
821             set_head = True
822
823         if commit:
824             if top:
825                 top_commit = git.get_commit(top)
826             else:
827                 bottom = head
828                 top_commit = git.get_commit(head)
829
830             # create a commit for the patch (may be empty if top == bottom);
831             # only commit on top of the current branch
832             assert(unapplied or bottom == head)
833             commit_id = git.commit(message = descr, parents = [bottom],
834                                    cache_update = False,
835                                    tree_id = top_commit.get_tree(),
836                                    allowempty = True, set_head = set_head,
837                                    author_name = author_name,
838                                    author_email = author_email,
839                                    author_date = author_date,
840                                    committer_name = committer_name,
841                                    committer_email = committer_email)
842             # set the patch top to the new commit
843             patch.set_top(commit_id)
844         else:
845             patch.set_top(top)
846
847         self.log_patch(patch, 'new')
848
849         return patch
850
851     def delete_patch(self, name, keep_log = False):
852         """Deletes a patch
853         """
854         self.__patch_name_valid(name)
855         patch = self.get_patch(name)
856
857         if self.__patch_is_current(patch):
858             self.pop_patch(name)
859         elif self.patch_applied(name):
860             raise StackException, 'Cannot remove an applied patch, "%s", ' \
861                   'which is not current' % name
862         elif not name in self.get_unapplied():
863             raise StackException, 'Unknown patch "%s"' % name
864
865         # save the commit id to a trash file
866         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
867
868         patch.delete(keep_log = keep_log)
869
870         unapplied = self.get_unapplied()
871         unapplied.remove(name)
872         write_strings(self.__unapplied_file, unapplied)
873
874     def forward_patches(self, names):
875         """Try to fast-forward an array of patches.
876
877         On return, patches in names[0:returned_value] have been pushed on the
878         stack. Apply the rest with push_patch
879         """
880         unapplied = self.get_unapplied()
881
882         forwarded = 0
883         top = git.get_head()
884
885         for name in names:
886             assert(name in unapplied)
887
888             patch = self.get_patch(name)
889
890             head = top
891             bottom = patch.get_bottom()
892             top = patch.get_top()
893
894             # top != bottom always since we have a commit for each patch
895             if head == bottom:
896                 # reset the backup information. No logging since the
897                 # patch hasn't changed
898                 patch.set_top(top, backup = True)
899
900             else:
901                 head_tree = git.get_commit(head).get_tree()
902                 bottom_tree = git.get_commit(bottom).get_tree()
903                 if head_tree == bottom_tree:
904                     # We must just reparent this patch and create a new commit
905                     # for it
906                     descr = patch.get_description()
907                     author_name = patch.get_authname()
908                     author_email = patch.get_authemail()
909                     author_date = patch.get_authdate()
910                     committer_name = patch.get_commname()
911                     committer_email = patch.get_commemail()
912
913                     top_tree = git.get_commit(top).get_tree()
914
915                     top = git.commit(message = descr, parents = [head],
916                                      cache_update = False,
917                                      tree_id = top_tree,
918                                      allowempty = True,
919                                      author_name = author_name,
920                                      author_email = author_email,
921                                      author_date = author_date,
922                                      committer_name = committer_name,
923                                      committer_email = committer_email)
924
925                     patch.set_top(top, backup = True)
926
927                     self.log_patch(patch, 'push(f)')
928                 else:
929                     top = head
930                     # stop the fast-forwarding, must do a real merge
931                     break
932
933             forwarded+=1
934             unapplied.remove(name)
935
936         if forwarded == 0:
937             return 0
938
939         git.switch(top)
940
941         append_strings(self.__applied_file, names[0:forwarded])
942         write_strings(self.__unapplied_file, unapplied)
943
944         return forwarded
945
946     def merged_patches(self, names):
947         """Test which patches were merged upstream by reverse-applying
948         them in reverse order. The function returns the list of
949         patches detected to have been applied. The state of the tree
950         is restored to the original one
951         """
952         patches = [self.get_patch(name) for name in names]
953         patches.reverse()
954
955         merged = []
956         for p in patches:
957             if git.apply_diff(p.get_top(), p.get_bottom()):
958                 merged.append(p.get_name())
959         merged.reverse()
960
961         git.reset()
962
963         return merged
964
965     def push_empty_patch(self, name):
966         """Pushes an empty patch on the stack
967         """
968         unapplied = self.get_unapplied()
969         assert(name in unapplied)
970
971         # patch = self.get_patch(name)
972         head = git.get_head()
973
974         append_string(self.__applied_file, name)
975
976         unapplied.remove(name)
977         write_strings(self.__unapplied_file, unapplied)
978
979         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
980
981     def push_patch(self, name):
982         """Pushes a patch on the stack
983         """
984         unapplied = self.get_unapplied()
985         assert(name in unapplied)
986
987         patch = self.get_patch(name)
988
989         head = git.get_head()
990         bottom = patch.get_bottom()
991         top = patch.get_top()
992         # top != bottom always since we have a commit for each patch
993
994         if head == bottom:
995             # A fast-forward push. Just reset the backup
996             # information. No need for logging
997             patch.set_top(top, backup = True)
998
999             git.switch(top)
1000             append_string(self.__applied_file, name)
1001
1002             unapplied.remove(name)
1003             write_strings(self.__unapplied_file, unapplied)
1004             return False
1005
1006         # Need to create a new commit an merge in the old patch
1007         ex = None
1008         modified = False
1009
1010         # Try the fast applying first. If this fails, fall back to the
1011         # three-way merge
1012         if not git.apply_diff(bottom, top):
1013             # if git.apply_diff() fails, the patch requires a diff3
1014             # merge and can be reported as modified
1015             modified = True
1016
1017             # merge can fail but the patch needs to be pushed
1018             try:
1019                 git.merge_recursive(bottom, head, top)
1020             except git.GitException, ex:
1021                 out.error('The merge failed during "push".',
1022                           'Revert the operation with "push --undo".')
1023
1024         append_string(self.__applied_file, name)
1025
1026         unapplied.remove(name)
1027         write_strings(self.__unapplied_file, unapplied)
1028
1029         if not ex:
1030             # if the merge was OK and no conflicts, just refresh the patch
1031             # The GIT cache was already updated by the merge operation
1032             if modified:
1033                 log = 'push(m)'
1034             else:
1035                 log = 'push'
1036             self.refresh_patch(bottom = head, cache_update = False, log = log)
1037         else:
1038             # we make the patch empty, with the merged state in the
1039             # working tree.
1040             self.refresh_patch(bottom = head, cache_update = False,
1041                                empty = True, log = 'push(c)')
1042             raise StackException, str(ex)
1043
1044         return modified
1045
1046     def undo_push(self):
1047         name = self.get_current()
1048         assert(name)
1049
1050         patch = self.get_patch(name)
1051         old_bottom = patch.get_old_bottom()
1052         old_top = patch.get_old_top()
1053
1054         # the top of the patch is changed by a push operation only
1055         # together with the bottom (otherwise the top was probably
1056         # modified by 'refresh'). If they are both unchanged, there
1057         # was a fast forward
1058         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1059             raise StackException, 'No undo information available'
1060
1061         git.reset()
1062         self.pop_patch(name)
1063         ret = patch.restore_old_boundaries()
1064         if ret:
1065             self.log_patch(patch, 'undo')
1066
1067         return ret
1068
1069     def pop_patch(self, name, keep = False):
1070         """Pops the top patch from the stack
1071         """
1072         applied = self.get_applied()
1073         applied.reverse()
1074         assert(name in applied)
1075
1076         patch = self.get_patch(name)
1077
1078         if git.get_head_file() == self.get_name():
1079             if keep and not git.apply_diff(git.get_head(), patch.get_bottom(),
1080                                            check_index = False):
1081                 raise StackException(
1082                     'Failed to pop patches while preserving the local changes')
1083             git.switch(patch.get_bottom(), keep)
1084         else:
1085             git.set_branch(self.get_name(), patch.get_bottom())
1086
1087         # save the new applied list
1088         idx = applied.index(name) + 1
1089
1090         popped = applied[:idx]
1091         popped.reverse()
1092         unapplied = popped + self.get_unapplied()
1093         write_strings(self.__unapplied_file, unapplied)
1094
1095         del applied[:idx]
1096         applied.reverse()
1097         write_strings(self.__applied_file, applied)
1098
1099     def empty_patch(self, name):
1100         """Returns True if the patch is empty
1101         """
1102         self.__patch_name_valid(name)
1103         patch = self.get_patch(name)
1104         bottom = patch.get_bottom()
1105         top = patch.get_top()
1106
1107         if bottom == top:
1108             return True
1109         elif git.get_commit(top).get_tree() \
1110                  == git.get_commit(bottom).get_tree():
1111             return True
1112
1113         return False
1114
1115     def rename_patch(self, oldname, newname):
1116         self.__patch_name_valid(newname)
1117
1118         applied = self.get_applied()
1119         unapplied = self.get_unapplied()
1120
1121         if oldname == newname:
1122             raise StackException, '"To" name and "from" name are the same'
1123
1124         if newname in applied or newname in unapplied:
1125             raise StackException, 'Patch "%s" already exists' % newname
1126
1127         if oldname in unapplied:
1128             self.get_patch(oldname).rename(newname)
1129             unapplied[unapplied.index(oldname)] = newname
1130             write_strings(self.__unapplied_file, unapplied)
1131         elif oldname in applied:
1132             self.get_patch(oldname).rename(newname)
1133
1134             applied[applied.index(oldname)] = newname
1135             write_strings(self.__applied_file, applied)
1136         else:
1137             raise StackException, 'Unknown patch "%s"' % oldname
1138
1139     def log_patch(self, patch, message, notes = None):
1140         """Generate a log commit for a patch
1141         """
1142         top = git.get_commit(patch.get_top())
1143         old_log = patch.get_log()
1144
1145         if message is None:
1146             # replace the current log entry
1147             if not old_log:
1148                 raise StackException, \
1149                       'No log entry to annotate for patch "%s"' \
1150                       % patch.get_name()
1151             replace = True
1152             log_commit = git.get_commit(old_log)
1153             msg = log_commit.get_log().split('\n')[0]
1154             log_parent = log_commit.get_parent()
1155             if log_parent:
1156                 parents = [log_parent]
1157             else:
1158                 parents = []
1159         else:
1160             # generate a new log entry
1161             replace = False
1162             msg = '%s\t%s' % (message, top.get_id_hash())
1163             if old_log:
1164                 parents = [old_log]
1165             else:
1166                 parents = []
1167
1168         if notes:
1169             msg += '\n\n' + notes
1170
1171         log = git.commit(message = msg, parents = parents,
1172                          cache_update = False, tree_id = top.get_tree(),
1173                          allowempty = True)
1174         patch.set_log(log)
1175
1176     def hide_patch(self, name):
1177         """Add the patch to the hidden list.
1178         """
1179         unapplied = self.get_unapplied()
1180         if name not in unapplied:
1181             # keep the checking order for backward compatibility with
1182             # the old hidden patches functionality
1183             if self.patch_applied(name):
1184                 raise StackException, 'Cannot hide applied patch "%s"' % name
1185             elif self.patch_hidden(name):
1186                 raise StackException, 'Patch "%s" already hidden' % name
1187             else:
1188                 raise StackException, 'Unknown patch "%s"' % name
1189
1190         if not self.patch_hidden(name):
1191             # check needed for backward compatibility with the old
1192             # hidden patches functionality
1193             append_string(self.__hidden_file, name)
1194
1195         unapplied.remove(name)
1196         write_strings(self.__unapplied_file, unapplied)
1197
1198     def unhide_patch(self, name):
1199         """Remove the patch from the hidden list.
1200         """
1201         hidden = self.get_hidden()
1202         if not name in hidden:
1203             if self.patch_applied(name) or self.patch_unapplied(name):
1204                 raise StackException, 'Patch "%s" not hidden' % name
1205             else:
1206                 raise StackException, 'Unknown patch "%s"' % name
1207
1208         hidden.remove(name)
1209         write_strings(self.__hidden_file, hidden)
1210
1211         if not self.patch_applied(name) and not self.patch_unapplied(name):
1212             # check needed for backward compatibility with the old
1213             # hidden patches functionality
1214             append_string(self.__unapplied_file, name)