chiark / gitweb /
d889f37797c18aee446ae9362f67b17af172675e
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit.run import *
27 from stgit import git, basedir, templates
28 from stgit.config import config
29 from shutil import copyfile
30
31
# stack exception class
class StackException(Exception):
    """Error raised by StGIT stack, series and patch operations."""
35
class FilterUntil:
    """Stateful line filter used when cleaning an edited message.

    Calling the instance with a line *x* returns True if the line should
    be kept. Lines starting with *prefix* are dropped. As soon as
    *until_test(x)* is true for some line, that line and every later
    line are dropped too.
    """
    def __init__(self):
        # becomes False (for good) once the cut-off line has been seen
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        if not self.should_print:
            return False
        return not x.startswith(prefix)
45
#
# Functions
#

# Lines of an edited message that start with this marker are removed.
__comment_prefix = 'STG:'
# A line made of exactly this marker (plus newline) ends the message:
# it and everything after it are removed.
__patch_prefix = 'STG_PATCH:'
51
def __clean_comments(f):
    """Strip StGIT markup from an edited message file, in place.

    Drops every line starting with __comment_prefix, the __patch_prefix
    line and everything after it, and any empty lines left at the end.
    The file is then truncated and rewritten with the remaining lines.
    """
    f.seek(0)

    keep = FilterUntil()
    is_patch_marker = lambda text: text == (__patch_prefix + '\n')
    kept = [text for text in f.readlines()
            if keep(text, is_patch_marker, __comment_prefix)]

    # trailing empty lines are not part of the message
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
70
# TODO: move this out of the stgit.stack module, it is really for
# higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch description and return the edited text.

    Writes *line* (or else the 'patchdescr.tmpl' template, or else an
    empty line) to '.stgitmsg.txt' in the current directory. The text is
    followed by __comment_prefix instruction lines and, if *show_patch*
    is true, by the __patch_prefix marker and a diff against the bottom
    of the topmost patch of *series*. The editor is then run on the file,
    the StGIT markup is stripped, the file is removed, and the remaining
    text is returned.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = open(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            print >> f, __patch_prefix
            # Diff from the bottom of the topmost patch. With no patch
            # applied, use the stack base instead of crashing on
            # get_patch(None).
            crt = series.get_current_patch()
            if crt:
                rev1 = crt.get_bottom()
            else:
                rev1 = series.get_base()
            f.write(git.diff(rev1 = rev1))

        # Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        f.close()

    call_editor(fname)

    f = open(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
113
114 #
115 # Classes
116 #
117
class StgitObject:
    """Base for objects whose properties ("fields") are stored as
    individual files inside a directory.
    """
    def _set_dir(self, dir):
        # directory that holds this object's field files
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def __field_path(self, name):
        # full path of the file backing field *name*
        return os.path.join(self.__dir, name)

    def create_empty_field(self, name):
        create_empty_file(self.__field_path(name))

    def _get_field(self, name, multiline = False):
        """Return the content of field *name*, or None when its file is
        missing or its content is empty."""
        path = self.__field_path(name)
        if not os.path.isfile(path):
            return None
        content = read_string(path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store *value* as field *name*. A false/empty value removes the
        field file instead, if it exists."""
        path = self.__field_path(name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
146
147     
class Patch(StgitObject):
    """A single StGIT patch.

    Its metadata (bottom/top commit ids, description, author, committer,
    log) is kept as field files in <series_dir>/<name>. The top commit
    and the log are also recorded in the git refs <refs_base>/<name> and
    <refs_base>/<name>.log.
    """
    def __init_refs(self):
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty bottom/top fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory, its top ref and (if present) its
        log ref."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch: its refs (the log ref only if present) and
        its directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the recorded top commit, if there is one."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom commit; with *backup*, first save the current
        value as 'bottom.old'."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top commit and the top ref; with *backup*, first save
        the current value as 'top.old'."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom/top from their '.old' backups.

        Returns True if both backups existed and were restored, False
        otherwise (nothing is changed then).
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        # an empty name falls back to git.author().name
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        # an empty e-mail falls back to git.author().email
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or None if it is not recorded.

        A date stored in git's raw "<unix seconds> <+|-HHMM>" form is
        converted to an RFC 2822 date string in that same time zone. Any
        other stored value is returned unchanged.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        if re.match(r'[0-9]+\s+[+-][0-9]+', date):
            # Unix time (seconds) + time zone
            secs_tz = date.split()
            secs = int(secs_tz[0])
            tz = secs_tz[1]
            # formatdate() prints the UTC wall-clock time followed by a
            # "-0000" zone, which is cut off below and replaced by *tz*.
            # Shift the timestamp by the zone offset first, so that the
            # printed wall-clock time agrees with the appended zone.
            # Without the shift the date would be off by that offset.
            if len(tz) == 5:
                offset = int(tz[1:3]) * 3600 + int(tz[3:5]) * 60
                if tz[0] == '-':
                    offset = -offset
                secs = secs + offset
            date = formatdate(secs)[:-5] + tz

        return date

    def set_authdate(self, date):
        # an empty date falls back to git.author().date
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        # an empty name falls back to git.committer().name
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        # an empty e-mail falls back to git.committer().email
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Record the log commit and update the log ref.

        *backup* is accepted for signature compatibility with the other
        setters but is not used.
        """
        self._set_field('log', value)
        self.__update_log_ref(value)
290
# Version of the StGIT metadata layout written by this code. Branches at
# an older version are upgraded by Series.update_to_current_format_version().
FORMAT_VERSION = 2
293
294 class PatchSet(StgitObject):
295     def __init__(self, name = None):
296         try:
297             if name:
298                 self.set_name (name)
299             else:
300                 self.set_name (git.get_head_file())
301             self.__base_dir = basedir.get()
302         except git.GitException, ex:
303             raise StackException, 'GIT tree not initialised: %s' % ex
304
305         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
306
307     def get_name(self):
308         return self.__name
309     def set_name(self, name):
310         self.__name = name
311
312     def _basedir(self):
313         return self.__base_dir
314
315     def get_head(self):
316         """Return the head of the branch
317         """
318         crt = self.get_current_patch()
319         if crt:
320             return crt.get_top()
321         else:
322             return self.get_base()
323
324     def get_protected(self):
325         return os.path.isfile(os.path.join(self._dir(), 'protected'))
326
327     def protect(self):
328         protect_file = os.path.join(self._dir(), 'protected')
329         if not os.path.isfile(protect_file):
330             create_empty_file(protect_file)
331
332     def unprotect(self):
333         protect_file = os.path.join(self._dir(), 'protected')
334         if os.path.isfile(protect_file):
335             os.remove(protect_file)
336
337     def __branch_descr(self):
338         return 'branch.%s.description' % self.get_name()
339
340     def get_description(self):
341         return config.get(self.__branch_descr()) or ''
342
343     def set_description(self, line):
344         if line:
345             config.set(self.__branch_descr(), line)
346         else:
347             config.unset(self.__branch_descr())
348
349     def head_top_equal(self):
350         """Return true if the head and the top are the same
351         """
352         crt = self.get_current_patch()
353         if not crt:
354             # we don't care, no patches applied
355             return True
356         return git.get_head() == crt.get_top()
357
358     def is_initialised(self):
359         """Checks if series is already initialised
360         """
361         return bool(config.get(self.format_version_key()))
362
363
def shortlog(patches):
    """Return the 'git-shortlog' summary of the commits in *patches*.

    The 'git-log --pretty=short' output for each patch's bottom..top range
    is concatenated in order and fed to 'git-shortlog'.
    """
    logs = []
    for p in patches:
        top = p.get_top()
        bottom = p.get_bottom()
        logs.append(Run('git-log', '--pretty=short',
                        top, '^%s' % bottom).raw_output())
    return Run('git-shortlog').raw_input(''.join(logs)).raw_output()
369
370 class Series(PatchSet):
371     """Class including the operations on series
372     """
373     def __init__(self, name = None):
374         """Takes a series name as the parameter.
375         """
376         PatchSet.__init__(self, name)
377
378         # Update the branch to the latest format version if it is
379         # initialized, but don't touch it if it isn't.
380         self.update_to_current_format_version()
381
382         self.__refs_base = 'refs/patches/%s' % self.get_name()
383
384         self.__applied_file = os.path.join(self._dir(), 'applied')
385         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
386         self.__hidden_file = os.path.join(self._dir(), 'hidden')
387
388         # where this series keeps its patches
389         self.__patch_dir = os.path.join(self._dir(), 'patches')
390
391         # trash directory
392         self.__trash_dir = os.path.join(self._dir(), 'trash')
393
394     def format_version_key(self):
395         return 'branch.%s.stgit.stackformatversion' % self.get_name()
396
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats.

        Upgrade steps run in sequence (0 -> 1, then 1 -> 2), each one
        only if the branch is at exactly that version, so a version 0
        branch is fully upgraded in one call. Uninitialised branches are
        left alone.

        Raises StackException if the branch ends up at any version other
        than FORMAT_VERSION (e.g. one newer than this code knows).
        """

        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version.

            Side effect: a version stored under the old config key is
            moved to format_version_key()."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        def set_format_version(v):
            # Records v in the git config, so later get_format_version()
            # calls return it.
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        # Small local helpers, kept here so the upgrade logic depends on
        # as little external code as possible (see the docstring).
        def mkdir(d):
            if not os.path.isdir(d):
                os.makedirs(d)
        def rm(f):
            if os.path.exists(f):
                os.remove(f)
        def rm_ref(ref):
            if git.ref_exists(ref):
                git.delete_ref(ref)

        # Update 0 -> 1.
        # Version 0 kept each patch's directory directly in branch_dir;
        # move them under branch_dir/patches and create the top refs.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_base = 'refs/patches/%s' % self.get_name()
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                Patch(patch, patch_dir, refs_base).update_top_ref()
            set_format_version(1)

        # Update 1 -> 2.
        # Move the description file into the git config, and drop the
        # 'current' file and the refs/bases/<branch> ref.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm_ref('refs/bases/%s' % self.get_name())
            set_format_version(2)

        # Make sure we're at the latest version.
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
473
474     def __patch_name_valid(self, name):
475         """Raise an exception if the patch name is not valid.
476         """
477         if not name or re.search('[^\w.-]', name):
478             raise StackException, 'Invalid patch name: "%s"' % name
479
480     def get_patch(self, name):
481         """Return a Patch object for the given name
482         """
483         return Patch(name, self.__patch_dir, self.__refs_base)
484
485     def get_current_patch(self):
486         """Return a Patch object representing the topmost patch, or
487         None if there is no such patch."""
488         crt = self.get_current()
489         if not crt:
490             return None
491         return self.get_patch(crt)
492
493     def get_current(self):
494         """Return the name of the topmost patch, or None if there is
495         no such patch."""
496         try:
497             applied = self.get_applied()
498         except StackException:
499             # No "applied" file: branch is not initialized.
500             return None
501         try:
502             return applied[-1]
503         except IndexError:
504             # No patches applied.
505             return None
506
507     def get_applied(self):
508         if not os.path.isfile(self.__applied_file):
509             raise StackException, 'Branch "%s" not initialised' % self.get_name()
510         return read_strings(self.__applied_file)
511
512     def set_applied(self, applied):
513         write_strings(self.__applied_file, applied)
514
515     def get_unapplied(self):
516         if not os.path.isfile(self.__unapplied_file):
517             raise StackException, 'Branch "%s" not initialised' % self.get_name()
518         return read_strings(self.__unapplied_file)
519
520     def set_unapplied(self, unapplied):
521         write_strings(self.__unapplied_file, unapplied)
522
523     def get_hidden(self):
524         if not os.path.isfile(self.__hidden_file):
525             return []
526         return read_strings(self.__hidden_file)
527
528     def get_base(self):
529         # Return the parent of the bottommost patch, if there is one.
530         if os.path.isfile(self.__applied_file):
531             bottommost = file(self.__applied_file).readline().strip()
532             if bottommost:
533                 return self.get_patch(bottommost).get_bottom()
534         # No bottommost patch, so just return HEAD
535         return git.get_head()
536
537     def get_parent_remote(self):
538         value = config.get('branch.%s.remote' % self.get_name())
539         if value:
540             return value
541         elif 'origin' in git.remotes_list():
542             out.note(('No parent remote declared for stack "%s",'
543                       ' defaulting to "origin".' % self.get_name()),
544                      ('Consider setting "branch.%s.remote" and'
545                       ' "branch.%s.merge" with "git config".'
546                       % (self.get_name(), self.get_name())))
547             return 'origin'
548         else:
549             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
550
551     def __set_parent_remote(self, remote):
552         value = config.set('branch.%s.remote' % self.get_name(), remote)
553
554     def get_parent_branch(self):
555         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
556         if value:
557             return value
558         elif git.rev_parse('heads/origin'):
559             out.note(('No parent branch declared for stack "%s",'
560                       ' defaulting to "heads/origin".' % self.get_name()),
561                      ('Consider setting "branch.%s.stgit.parentbranch"'
562                       ' with "git config".' % self.get_name()))
563             return 'heads/origin'
564         else:
565             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
566
567     def __set_parent_branch(self, name):
568         if config.get('branch.%s.remote' % self.get_name()):
569             # Never set merge if remote is not set to avoid
570             # possibly-erroneous lookups into 'origin'
571             config.set('branch.%s.merge' % self.get_name(), name)
572         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
573
574     def set_parent(self, remote, localbranch):
575         if localbranch:
576             if remote:
577                 self.__set_parent_remote(remote)
578             self.__set_parent_branch(localbranch)
579         # We'll enforce this later
580 #         else:
581 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
582
583     def __patch_is_current(self, patch):
584         return patch.get_name() == self.get_current()
585
586     def patch_applied(self, name):
587         """Return true if the patch exists in the applied list
588         """
589         return name in self.get_applied()
590
591     def patch_unapplied(self, name):
592         """Return true if the patch exists in the unapplied list
593         """
594         return name in self.get_unapplied()
595
596     def patch_hidden(self, name):
597         """Return true if the patch is hidden.
598         """
599         return name in self.get_hidden()
600
601     def patch_exists(self, name):
602         """Return true if there is a patch with the given name, false
603         otherwise."""
604         return self.patch_applied(name) or self.patch_unapplied(name) \
605                or self.patch_hidden(name)
606
607     def init(self, create_at=False, parent_remote=None, parent_branch=None):
608         """Initialises the stgit series
609         """
610         if self.is_initialised():
611             raise StackException, '%s already initialized' % self.get_name()
612         for d in [self._dir()]:
613             if os.path.exists(d):
614                 raise StackException, '%s already exists' % d
615
616         if (create_at!=False):
617             git.create_branch(self.get_name(), create_at)
618
619         os.makedirs(self.__patch_dir)
620
621         self.set_parent(parent_remote, parent_branch)
622
623         self.create_empty_field('applied')
624         self.create_empty_field('unapplied')
625         self._set_field('orig-base', git.get_head())
626
627         config.set(self.format_version_key(), str(FORMAT_VERSION))
628
629     def rename(self, to_name):
630         """Renames a series
631         """
632         to_stack = Series(to_name)
633
634         if to_stack.is_initialised():
635             raise StackException, '"%s" already exists' % to_stack.get_name()
636
637         patches = self.get_applied() + self.get_unapplied()
638
639         git.rename_branch(self.get_name(), to_name)
640
641         for patch in patches:
642             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
643                            'refs/patches/%s/%s' % (to_name, patch))
644             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
645                            'refs/patches/%s/%s.log' % (to_name, patch))
646         if os.path.isdir(self._dir()):
647             rename(os.path.join(self._basedir(), 'patches'),
648                    self.get_name(), to_stack.get_name())
649
650         # Rename the config section
651         for k in ['branch.%s', 'branch.%s.stgit']:
652             config.rename_section(k % self.get_name(), k % to_name)
653
654         self.__init__(to_name)
655
656     def clone(self, target_series):
657         """Clones a series
658         """
659         try:
660             # allow cloning of branches not under StGIT control
661             base = self.get_base()
662         except:
663             base = git.get_head()
664         Series(target_series).init(create_at = base)
665         new_series = Series(target_series)
666
667         # generate an artificial description file
668         new_series.set_description('clone of "%s"' % self.get_name())
669
670         # clone self's entire series as unapplied patches
671         try:
672             # allow cloning of branches not under StGIT control
673             applied = self.get_applied()
674             unapplied = self.get_unapplied()
675             patches = applied + unapplied
676             patches.reverse()
677         except:
678             patches = applied = unapplied = []
679         for p in patches:
680             patch = self.get_patch(p)
681             newpatch = new_series.new_patch(p, message = patch.get_description(),
682                                             can_edit = False, unapplied = True,
683                                             bottom = patch.get_bottom(),
684                                             top = patch.get_top(),
685                                             author_name = patch.get_authname(),
686                                             author_email = patch.get_authemail(),
687                                             author_date = patch.get_authdate())
688             if patch.get_log():
689                 out.info('Setting log to %s' %  patch.get_log())
690                 newpatch.set_log(patch.get_log())
691             else:
692                 out.info('No log for %s' % p)
693
694         # fast forward the cloned series to self's top
695         new_series.forward_patches(applied)
696
697         # Clone parent informations
698         value = config.get('branch.%s.remote' % self.get_name())
699         if value:
700             config.set('branch.%s.remote' % target_series, value)
701
702         value = config.get('branch.%s.merge' % self.get_name())
703         if value:
704             config.set('branch.%s.merge' % target_series, value)
705
706         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
707         if value:
708             config.set('branch.%s.stgit.parentbranch' % target_series, value)
709
710     def delete(self, force = False):
711         """Deletes an stgit series
712         """
713         if self.is_initialised():
714             patches = self.get_unapplied() + self.get_applied()
715             if not force and patches:
716                 raise StackException, \
717                       'Cannot delete: the series still contains patches'
718             for p in patches:
719                 self.get_patch(p).delete()
720
721             # remove the trash directory if any
722             if os.path.exists(self.__trash_dir):
723                 for fname in os.listdir(self.__trash_dir):
724                     os.remove(os.path.join(self.__trash_dir, fname))
725                 os.rmdir(self.__trash_dir)
726
727             # FIXME: find a way to get rid of those manual removals
728             # (move functionality to StgitObject ?)
729             if os.path.exists(self.__applied_file):
730                 os.remove(self.__applied_file)
731             if os.path.exists(self.__unapplied_file):
732                 os.remove(self.__unapplied_file)
733             if os.path.exists(self.__hidden_file):
734                 os.remove(self.__hidden_file)
735             if os.path.exists(self._dir()+'/orig-base'):
736                 os.remove(self._dir()+'/orig-base')
737
738             if not os.listdir(self.__patch_dir):
739                 os.rmdir(self.__patch_dir)
740             else:
741                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
742
743             try:
744                 os.removedirs(self._dir())
745             except OSError:
746                 raise StackException('Series directory %s is not empty'
747                                      % self._dir())
748
749             try:
750                 git.delete_branch(self.get_name())
751             except GitException:
752                 out.warn('Could not delete branch "%s"' % self.get_name())
753
754         config.remove_section('branch.%s' % self.get_name())
755         config.remove_section('branch.%s.stgit' % self.get_name())
756
757     def refresh_patch(self, files = None, message = None, edit = False,
758                       show_patch = False,
759                       cache_update = True,
760                       author_name = None, author_email = None,
761                       author_date = None,
762                       committer_name = None, committer_email = None,
763                       backup = False, sign_str = None, log = 'refresh',
764                       notes = None, bottom = None):
765         """Generates a new commit for the topmost patch
766         """
767         patch = self.get_current_patch()
768         if not patch:
769             raise StackException, 'No patches applied'
770
771         descr = patch.get_description()
772         if not (message or descr):
773             edit = True
774             descr = ''
775         elif message:
776             descr = message
777
778         # TODO: move this out of the stgit.stack module, it is really
779         # for higher level commands to handle the user interaction
780         if not message and edit:
781             descr = edit_file(self, descr.rstrip(), \
782                               'Please edit the description for patch "%s" ' \
783                               'above.' % patch.get_name(), show_patch)
784
785         if not author_name:
786             author_name = patch.get_authname()
787         if not author_email:
788             author_email = patch.get_authemail()
789         if not author_date:
790             author_date = patch.get_authdate()
791         if not committer_name:
792             committer_name = patch.get_commname()
793         if not committer_email:
794             committer_email = patch.get_commemail()
795
796         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
797
798         if not bottom:
799             bottom = patch.get_bottom()
800
801         commit_id = git.commit(files = files,
802                                message = descr, parents = [bottom],
803                                cache_update = cache_update,
804                                allowempty = True,
805                                author_name = author_name,
806                                author_email = author_email,
807                                author_date = author_date,
808                                committer_name = committer_name,
809                                committer_email = committer_email)
810
811         patch.set_bottom(bottom, backup = backup)
812         patch.set_top(commit_id, backup = backup)
813         patch.set_description(descr)
814         patch.set_authname(author_name)
815         patch.set_authemail(author_email)
816         patch.set_authdate(author_date)
817         patch.set_commname(committer_name)
818         patch.set_commemail(committer_email)
819
820         if log:
821             self.log_patch(patch, log, notes)
822
823         return commit_id
824
825     def undo_refresh(self):
826         """Undo the patch boundaries changes caused by 'refresh'
827         """
828         name = self.get_current()
829         assert(name)
830
831         patch = self.get_patch(name)
832         old_bottom = patch.get_old_bottom()
833         old_top = patch.get_old_top()
834
835         # the bottom of the patch is not changed by refresh. If the
836         # old_bottom is different, there wasn't any previous 'refresh'
837         # command (probably only a 'push')
838         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
839             raise StackException, 'No undo information available'
840
841         git.reset(tree_id = old_top, check_out = False)
842         if patch.restore_old_boundaries():
843             self.log_patch(patch, 'undo')
844
845     def new_patch(self, name, message = None, can_edit = True,
846                   unapplied = False, show_patch = False,
847                   top = None, bottom = None, commit = True,
848                   author_name = None, author_email = None, author_date = None,
849                   committer_name = None, committer_email = None,
850                   before_existing = False):
851         """Creates a new patch, either pointing to an existing commit object,
852         or by creating a new commit object.
853         """
854
855         assert commit or (top and bottom)
856         assert not before_existing or (top and bottom)
857         assert not (commit and before_existing)
858         assert (top and bottom) or (not top and not bottom)
859         assert not top or (bottom == git.get_commit(top).get_parent())
860
861         if name != None:
862             self.__patch_name_valid(name)
863             if self.patch_exists(name):
864                 raise StackException, 'Patch "%s" already exists' % name
865
866         # TODO: move this out of the stgit.stack module, it is really
867         # for higher level commands to handle the user interaction
868         if not message and can_edit:
869             descr = edit_file(
870                 self, None,
871                 'Please enter the description for the patch above.',
872                 show_patch)
873         else:
874             descr = message
875
876         head = git.get_head()
877
878         if name == None:
879             name = make_patch_name(descr, self.patch_exists)
880
881         patch = self.get_patch(name)
882         patch.create()
883
884         patch.set_description(descr)
885         patch.set_authname(author_name)
886         patch.set_authemail(author_email)
887         patch.set_authdate(author_date)
888         patch.set_commname(committer_name)
889         patch.set_commemail(committer_email)
890
891         if before_existing:
892             insert_string(self.__applied_file, patch.get_name())
893         elif unapplied:
894             patches = [patch.get_name()] + self.get_unapplied()
895             write_strings(self.__unapplied_file, patches)
896             set_head = False
897         else:
898             append_string(self.__applied_file, patch.get_name())
899             set_head = True
900
901         if commit:
902             if top:
903                 top_commit = git.get_commit(top)
904             else:
905                 bottom = head
906                 top_commit = git.get_commit(head)
907
908             # create a commit for the patch (may be empty if top == bottom);
909             # only commit on top of the current branch
910             assert(unapplied or bottom == head)
911             commit_id = git.commit(message = descr, parents = [bottom],
912                                    cache_update = False,
913                                    tree_id = top_commit.get_tree(),
914                                    allowempty = True, set_head = set_head,
915                                    author_name = author_name,
916                                    author_email = author_email,
917                                    author_date = author_date,
918                                    committer_name = committer_name,
919                                    committer_email = committer_email)
920             # set the patch top to the new commit
921             patch.set_bottom(bottom)
922             patch.set_top(commit_id)
923         else:
924             assert top != bottom
925             patch.set_bottom(bottom)
926             patch.set_top(top)
927
928         self.log_patch(patch, 'new')
929
930         return patch
931
932     def delete_patch(self, name):
933         """Deletes a patch
934         """
935         self.__patch_name_valid(name)
936         patch = self.get_patch(name)
937
938         if self.__patch_is_current(patch):
939             self.pop_patch(name)
940         elif self.patch_applied(name):
941             raise StackException, 'Cannot remove an applied patch, "%s", ' \
942                   'which is not current' % name
943         elif not name in self.get_unapplied():
944             raise StackException, 'Unknown patch "%s"' % name
945
946         # save the commit id to a trash file
947         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
948
949         patch.delete()
950
951         unapplied = self.get_unapplied()
952         unapplied.remove(name)
953         write_strings(self.__unapplied_file, unapplied)
954
955     def forward_patches(self, names):
956         """Try to fast-forward an array of patches.
957
958         On return, patches in names[0:returned_value] have been pushed on the
959         stack. Apply the rest with push_patch
960         """
961         unapplied = self.get_unapplied()
962
963         forwarded = 0
964         top = git.get_head()
965
966         for name in names:
967             assert(name in unapplied)
968
969             patch = self.get_patch(name)
970
971             head = top
972             bottom = patch.get_bottom()
973             top = patch.get_top()
974
975             # top != bottom always since we have a commit for each patch
976             if head == bottom:
977                 # reset the backup information. No logging since the
978                 # patch hasn't changed
979                 patch.set_bottom(head, backup = True)
980                 patch.set_top(top, backup = True)
981
982             else:
983                 head_tree = git.get_commit(head).get_tree()
984                 bottom_tree = git.get_commit(bottom).get_tree()
985                 if head_tree == bottom_tree:
986                     # We must just reparent this patch and create a new commit
987                     # for it
988                     descr = patch.get_description()
989                     author_name = patch.get_authname()
990                     author_email = patch.get_authemail()
991                     author_date = patch.get_authdate()
992                     committer_name = patch.get_commname()
993                     committer_email = patch.get_commemail()
994
995                     top_tree = git.get_commit(top).get_tree()
996
997                     top = git.commit(message = descr, parents = [head],
998                                      cache_update = False,
999                                      tree_id = top_tree,
1000                                      allowempty = True,
1001                                      author_name = author_name,
1002                                      author_email = author_email,
1003                                      author_date = author_date,
1004                                      committer_name = committer_name,
1005                                      committer_email = committer_email)
1006
1007                     patch.set_bottom(head, backup = True)
1008                     patch.set_top(top, backup = True)
1009
1010                     self.log_patch(patch, 'push(f)')
1011                 else:
1012                     top = head
1013                     # stop the fast-forwarding, must do a real merge
1014                     break
1015
1016             forwarded+=1
1017             unapplied.remove(name)
1018
1019         if forwarded == 0:
1020             return 0
1021
1022         git.switch(top)
1023
1024         append_strings(self.__applied_file, names[0:forwarded])
1025         write_strings(self.__unapplied_file, unapplied)
1026
1027         return forwarded
1028
1029     def merged_patches(self, names):
1030         """Test which patches were merged upstream by reverse-applying
1031         them in reverse order. The function returns the list of
1032         patches detected to have been applied. The state of the tree
1033         is restored to the original one
1034         """
1035         patches = [self.get_patch(name) for name in names]
1036         patches.reverse()
1037
1038         merged = []
1039         for p in patches:
1040             if git.apply_diff(p.get_top(), p.get_bottom()):
1041                 merged.append(p.get_name())
1042         merged.reverse()
1043
1044         git.reset()
1045
1046         return merged
1047
1048     def push_empty_patch(self, name):
1049         """Pushes an empty patch on the stack
1050         """
1051         unapplied = self.get_unapplied()
1052         assert(name in unapplied)
1053
1054         # patch = self.get_patch(name)
1055         head = git.get_head()
1056
1057         append_string(self.__applied_file, name)
1058
1059         unapplied.remove(name)
1060         write_strings(self.__unapplied_file, unapplied)
1061
1062         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1063
1064     def push_patch(self, name):
1065         """Pushes a patch on the stack
1066         """
1067         unapplied = self.get_unapplied()
1068         assert(name in unapplied)
1069
1070         patch = self.get_patch(name)
1071
1072         head = git.get_head()
1073         bottom = patch.get_bottom()
1074         top = patch.get_top()
1075         # top != bottom always since we have a commit for each patch
1076
1077         if head == bottom:
1078             # A fast-forward push. Just reset the backup
1079             # information. No need for logging
1080             patch.set_bottom(bottom, backup = True)
1081             patch.set_top(top, backup = True)
1082
1083             git.switch(top)
1084             append_string(self.__applied_file, name)
1085
1086             unapplied.remove(name)
1087             write_strings(self.__unapplied_file, unapplied)
1088             return False
1089
1090         # Need to create a new commit an merge in the old patch
1091         ex = None
1092         modified = False
1093
1094         # Try the fast applying first. If this fails, fall back to the
1095         # three-way merge
1096         if not git.apply_diff(bottom, top):
1097             # if git.apply_diff() fails, the patch requires a diff3
1098             # merge and can be reported as modified
1099             modified = True
1100
1101             # merge can fail but the patch needs to be pushed
1102             try:
1103                 git.merge(bottom, head, top, recursive = True)
1104             except git.GitException, ex:
1105                 out.error('The merge failed during "push".',
1106                           'Use "refresh" after fixing the conflicts or'
1107                           ' revert the operation with "push --undo".')
1108
1109         append_string(self.__applied_file, name)
1110
1111         unapplied.remove(name)
1112         write_strings(self.__unapplied_file, unapplied)
1113
1114         if not ex:
1115             # if the merge was OK and no conflicts, just refresh the patch
1116             # The GIT cache was already updated by the merge operation
1117             if modified:
1118                 log = 'push(m)'
1119             else:
1120                 log = 'push'
1121             self.refresh_patch(bottom = head, cache_update = False, log = log)
1122         else:
1123             # we store the correctly merged files only for
1124             # tracking the conflict history. Note that the
1125             # git.merge() operations should always leave the index
1126             # in a valid state (i.e. only stage 0 files)
1127             self.refresh_patch(bottom = head, cache_update = False,
1128                                log = 'push(c)')
1129             raise StackException, str(ex)
1130
1131         return modified
1132
1133     def undo_push(self):
1134         name = self.get_current()
1135         assert(name)
1136
1137         patch = self.get_patch(name)
1138         old_bottom = patch.get_old_bottom()
1139         old_top = patch.get_old_top()
1140
1141         # the top of the patch is changed by a push operation only
1142         # together with the bottom (otherwise the top was probably
1143         # modified by 'refresh'). If they are both unchanged, there
1144         # was a fast forward
1145         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1146             raise StackException, 'No undo information available'
1147
1148         git.reset()
1149         self.pop_patch(name)
1150         ret = patch.restore_old_boundaries()
1151         if ret:
1152             self.log_patch(patch, 'undo')
1153
1154         return ret
1155
1156     def pop_patch(self, name, keep = False):
1157         """Pops the top patch from the stack
1158         """
1159         applied = self.get_applied()
1160         applied.reverse()
1161         assert(name in applied)
1162
1163         patch = self.get_patch(name)
1164
1165         if git.get_head_file() == self.get_name():
1166             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1167                 raise StackException(
1168                     'Failed to pop patches while preserving the local changes')
1169             git.switch(patch.get_bottom(), keep)
1170         else:
1171             git.set_branch(self.get_name(), patch.get_bottom())
1172
1173         # save the new applied list
1174         idx = applied.index(name) + 1
1175
1176         popped = applied[:idx]
1177         popped.reverse()
1178         unapplied = popped + self.get_unapplied()
1179         write_strings(self.__unapplied_file, unapplied)
1180
1181         del applied[:idx]
1182         applied.reverse()
1183         write_strings(self.__applied_file, applied)
1184
1185     def empty_patch(self, name):
1186         """Returns True if the patch is empty
1187         """
1188         self.__patch_name_valid(name)
1189         patch = self.get_patch(name)
1190         bottom = patch.get_bottom()
1191         top = patch.get_top()
1192
1193         if bottom == top:
1194             return True
1195         elif git.get_commit(top).get_tree() \
1196                  == git.get_commit(bottom).get_tree():
1197             return True
1198
1199         return False
1200
1201     def rename_patch(self, oldname, newname):
1202         self.__patch_name_valid(newname)
1203
1204         applied = self.get_applied()
1205         unapplied = self.get_unapplied()
1206
1207         if oldname == newname:
1208             raise StackException, '"To" name and "from" name are the same'
1209
1210         if newname in applied or newname in unapplied:
1211             raise StackException, 'Patch "%s" already exists' % newname
1212
1213         if oldname in unapplied:
1214             self.get_patch(oldname).rename(newname)
1215             unapplied[unapplied.index(oldname)] = newname
1216             write_strings(self.__unapplied_file, unapplied)
1217         elif oldname in applied:
1218             self.get_patch(oldname).rename(newname)
1219
1220             applied[applied.index(oldname)] = newname
1221             write_strings(self.__applied_file, applied)
1222         else:
1223             raise StackException, 'Unknown patch "%s"' % oldname
1224
1225     def log_patch(self, patch, message, notes = None):
1226         """Generate a log commit for a patch
1227         """
1228         top = git.get_commit(patch.get_top())
1229         old_log = patch.get_log()
1230
1231         if message is None:
1232             # replace the current log entry
1233             if not old_log:
1234                 raise StackException, \
1235                       'No log entry to annotate for patch "%s"' \
1236                       % patch.get_name()
1237             replace = True
1238             log_commit = git.get_commit(old_log)
1239             msg = log_commit.get_log().split('\n')[0]
1240             log_parent = log_commit.get_parent()
1241             if log_parent:
1242                 parents = [log_parent]
1243             else:
1244                 parents = []
1245         else:
1246             # generate a new log entry
1247             replace = False
1248             msg = '%s\t%s' % (message, top.get_id_hash())
1249             if old_log:
1250                 parents = [old_log]
1251             else:
1252                 parents = []
1253
1254         if notes:
1255             msg += '\n\n' + notes
1256
1257         log = git.commit(message = msg, parents = parents,
1258                          cache_update = False, tree_id = top.get_tree(),
1259                          allowempty = True)
1260         patch.set_log(log)
1261
1262     def hide_patch(self, name):
1263         """Add the patch to the hidden list.
1264         """
1265         unapplied = self.get_unapplied()
1266         if name not in unapplied:
1267             # keep the checking order for backward compatibility with
1268             # the old hidden patches functionality
1269             if self.patch_applied(name):
1270                 raise StackException, 'Cannot hide applied patch "%s"' % name
1271             elif self.patch_hidden(name):
1272                 raise StackException, 'Patch "%s" already hidden' % name
1273             else:
1274                 raise StackException, 'Unknown patch "%s"' % name
1275
1276         if not self.patch_hidden(name):
1277             # check needed for backward compatibility with the old
1278             # hidden patches functionality
1279             append_string(self.__hidden_file, name)
1280
1281         unapplied.remove(name)
1282         write_strings(self.__unapplied_file, unapplied)
1283
1284     def unhide_patch(self, name):
1285         """Remove the patch from the hidden list.
1286         """
1287         hidden = self.get_hidden()
1288         if not name in hidden:
1289             if self.patch_applied(name) or self.patch_unapplied(name):
1290                 raise StackException, 'Patch "%s" not hidden' % name
1291             else:
1292                 raise StackException, 'Unknown patch "%s"' % name
1293
1294         hidden.remove(name)
1295         write_strings(self.__hidden_file, hidden)
1296
1297         if not self.patch_applied(name) and not self.patch_unapplied(name):
1298             # check needed for backward compatibility with the old
1299             # hidden patches functionality
1300             append_string(self.__unapplied_file, name)