chiark / gitweb /
Add --ack/--sign options to "stg new"
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit.run import *
27 from stgit import git, basedir, templates
28 from stgit.config import config
29 from shutil import copyfile
30
31
32 # stack exception class
class StackException(Exception):
    """Error raised by the stack/series operations in this module."""
35
class FilterUntil:
    """Stateful line filter, meant to be used inside a list comprehension.

    Calling the instance as ``filt(line, until_test, prefix)`` returns True
    for lines that should be kept.  Before ``until_test`` first returns true,
    lines starting with ``prefix`` are rejected and all others kept; the line
    for which ``until_test`` is true, and every line after it, are rejected.
    """
    def __init__(self):
        # becomes False (permanently) once the "until" line has been seen
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        if not self.should_print:
            return False
        return not x.startswith(prefix)
45
#
# Functions
#

# Lines of an edited message that start with this prefix are dropped.
__comment_prefix = 'STG:'
# A line consisting of exactly this marker (plus newline) introduces the
# diff appended to an edited message; it and everything after it are dropped.
__patch_prefix = 'STG_PATCH:'
51
def __clean_comments(f):
    """Strip StGIT annotations from the open, seekable file f in place.

    Lines starting with the comment prefix are dropped, the patch-marker
    line and everything after it are dropped, and trailing lines that are
    exactly '\\n' are removed.  The file is then rewritten from the start
    and truncated to the new contents.
    """
    f.seek(0)
    marker = __patch_prefix + '\n'

    kept = []
    for text in f.readlines():
        if text == marker:
            # the appended diff starts here: discard it all
            break
        if not text.startswith(__comment_prefix):
            kept.append(text)

    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
70
# TODO: move this out of the stgit.stack module, it is really for
# higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch message in an external editor.

    A temporary '.stgitmsg.txt' is filled with line (or the
    'patchdescr.tmpl' template when line is empty), the comment and
    usage notes as comment-prefixed lines, and optionally the diff of
    the topmost patch, then handed to call_editor().  Comment lines and
    the diff section are stripped from the result.

    series     -- the series whose current patch is shown when show_patch
                  is set
    line       -- initial message text (may be empty)
    comment    -- one-line instruction shown to the user
    show_patch -- append a diff below a patch-marker line

    Returns the edited message text.  The temporary file is removed once
    the message has been read back.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            print >> f, __patch_prefix
            # Diff against the bottom of the topmost patch; with no patch
            # applied fall back to the stack base instead of crashing on
            # get_patch(None).
            crt = series.get_current()
            if crt:
                diff_base = series.get_patch(crt).get_bottom()
            else:
                diff_base = series.get_base()
            f.write(git.diff(rev1 = diff_base))

        # Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        f.close()

    call_editor(fname)

    f = file(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
113
114 #
115 # Classes
116 #
117
class StgitObject:
    """Base class for objects whose properties ("fields") are stored as
    individual files inside one directory.

    _set_dir() must be called before any other method is used.
    """
    def _set_dir(self, dir):
        """Set the directory that holds this object's field files."""
        self.__dir = dir

    def _dir(self):
        """Return the directory that holds this object's field files."""
        return self.__dir

    def create_empty_field(self, name):
        """Create the field file name as an empty file."""
        path = os.path.join(self.__dir, name)
        create_empty_file(path)

    def _get_field(self, name, multiline = False):
        """Return the contents of field name, or None when its file is
        missing or empty.  multiline is forwarded to read_string()."""
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        value = read_string(path, multiline)
        if value == '':
            return None
        return value

    def _set_field(self, name, value, multiline = False):
        """Store value in field name.  A false value (None, '') removes the
        field file instead.  multiline is forwarded to write_string()."""
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
        elif os.path.isfile(path):
            os.remove(path)
146
147     
class Patch(StgitObject):
    """A single StGIT patch.

    Its metadata (bottom, top, description, author/committer data, log) is
    kept as field files in <series_dir>/<name>; the top commit and the log
    commit are also recorded in the git refs <refs_base>/<name> and
    <refs_base>/<name>.log.
    """
    def __init_refs(self):
        # (Re)compute the ref names from the current name and refs base.
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty bottom and top fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory, its top ref and, if present, its
        log ref."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch, moving its refs and its metadata directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the recorded top commit, if there is one."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom; with backup, save the previous one as bottom.old."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top (and the top ref); with backup, save the previous
        top as top.old."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom/top from bottom.old/top.old.

        Returns True if both backups existed and were restored, else False.
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        """Set the author name; a false name defaults to git.author()."""
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        """Set the author email; a false email defaults to git.author()."""
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or the stored false value if unset.

        A date stored in git's raw "<unix seconds> <+|-HHMM>" form is
        converted to an RFC 2822 string expressed in that time zone, so it
        denotes the same instant; any other stored value is returned as is.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        m = re.match(r'([0-9]+)\s+([+-])([0-9]{2})([0-9]{2})\s*$', date)
        if m:
            secs, sign, hours, minutes = m.groups()
            offset = int(hours) * 3600 + int(minutes) * 60
            if sign == '-':
                offset = -offset
            # formatdate() renders UTC wall-clock time with a '-0000' suffix;
            # shift the clock into the recorded zone before replacing that
            # suffix, otherwise the result is off by the zone offset.
            date = formatdate(int(secs) + offset)[:-5] + sign + hours + minutes

        return date

    def set_authdate(self, date):
        """Set the author date; a false date defaults to git.author()."""
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        """Set the committer name; a false name defaults to git.committer()."""
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        """Set the committer email; a false email defaults to
        git.committer()."""
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Record value as the patch log and point the log ref at it.
        backup is accepted for interface symmetry but is not used."""
        self._set_field('log', value)
        self.__update_log_ref(value)
290
# Version of the on-disk StGIT metadata layout written by this code; older
# branches are upgraded by Series.update_to_current_format_version().
FORMAT_VERSION = 2
293
294 class PatchSet(StgitObject):
295     def __init__(self, name = None):
296         try:
297             if name:
298                 self.set_name (name)
299             else:
300                 self.set_name (git.get_head_file())
301             self.__base_dir = basedir.get()
302         except git.GitException, ex:
303             raise StackException, 'GIT tree not initialised: %s' % ex
304
305         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
306
307     def get_name(self):
308         return self.__name
309     def set_name(self, name):
310         self.__name = name
311
312     def _basedir(self):
313         return self.__base_dir
314
315     def get_head(self):
316         """Return the head of the branch
317         """
318         crt = self.get_current_patch()
319         if crt:
320             return crt.get_top()
321         else:
322             return self.get_base()
323
324     def get_protected(self):
325         return os.path.isfile(os.path.join(self._dir(), 'protected'))
326
327     def protect(self):
328         protect_file = os.path.join(self._dir(), 'protected')
329         if not os.path.isfile(protect_file):
330             create_empty_file(protect_file)
331
332     def unprotect(self):
333         protect_file = os.path.join(self._dir(), 'protected')
334         if os.path.isfile(protect_file):
335             os.remove(protect_file)
336
337     def __branch_descr(self):
338         return 'branch.%s.description' % self.get_name()
339
340     def get_description(self):
341         return config.get(self.__branch_descr()) or ''
342
343     def set_description(self, line):
344         if line:
345             config.set(self.__branch_descr(), line)
346         else:
347             config.unset(self.__branch_descr())
348
349     def head_top_equal(self):
350         """Return true if the head and the top are the same
351         """
352         crt = self.get_current_patch()
353         if not crt:
354             # we don't care, no patches applied
355             return True
356         return git.get_head() == crt.get_top()
357
358     def is_initialised(self):
359         """Checks if series is already initialised
360         """
361         return bool(config.get(self.format_version_key()))
362
363
def shortlog(patches):
    """Return the 'git-shortlog' summary of the commits in patches.

    For each patch, the commits between its bottom (exclusive) and top are
    listed with 'git-log --pretty=short'; the concatenated logs are then
    fed to git-shortlog.
    """
    logs = []
    for p in patches:
        logs.append(Run('git-log', '--pretty=short',
                        p.get_top(), '^%s' % p.get_bottom()).raw_output())
    combined = ''.join(logs)
    return Run('git-shortlog').raw_input(combined).raw_output()
369
370 class Series(PatchSet):
371     """Class including the operations on series
372     """
373     def __init__(self, name = None):
374         """Takes a series name as the parameter.
375         """
376         PatchSet.__init__(self, name)
377
378         # Update the branch to the latest format version if it is
379         # initialized, but don't touch it if it isn't.
380         self.update_to_current_format_version()
381
382         self.__refs_base = 'refs/patches/%s' % self.get_name()
383
384         self.__applied_file = os.path.join(self._dir(), 'applied')
385         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
386         self.__hidden_file = os.path.join(self._dir(), 'hidden')
387
388         # where this series keeps its patches
389         self.__patch_dir = os.path.join(self._dir(), 'patches')
390
391         # trash directory
392         self.__trash_dir = os.path.join(self._dir(), 'trash')
393
394     def format_version_key(self):
395         return 'branch.%s.stgit.stackformatversion' % self.get_name()
396
397     def update_to_current_format_version(self):
398         """Update a potentially older StGIT directory structure to the
399         latest version. Note: This function should depend as little as
400         possible on external functions that may change during a format
401         version bump, since it must remain able to process older formats."""
402
403         branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
404         def get_format_version():
405             """Return the integer format version number, or None if the
406             branch doesn't have any StGIT metadata at all, of any version."""
407             fv = config.get(self.format_version_key())
408             ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
409             if fv:
410                 # Great, there's an explicitly recorded format version
411                 # number, which means that the branch is initialized and
412                 # of that exact version.
413                 return int(fv)
414             elif ofv:
415                 # Old name for the version info, upgrade it
416                 config.set(self.format_version_key(), ofv)
417                 config.unset('branch.%s.stgitformatversion' % self.get_name())
418                 return int(ofv)
419             elif os.path.isdir(os.path.join(branch_dir, 'patches')):
420                 # There's a .git/patches/<branch>/patches dirctory, which
421                 # means this is an initialized version 1 branch.
422                 return 1
423             elif os.path.isdir(branch_dir):
424                 # There's a .git/patches/<branch> directory, which means
425                 # this is an initialized version 0 branch.
426                 return 0
427             else:
428                 # The branch doesn't seem to be initialized at all.
429                 return None
430         def set_format_version(v):
431             out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
432             config.set(self.format_version_key(), '%d' % v)
433         def mkdir(d):
434             if not os.path.isdir(d):
435                 os.makedirs(d)
436         def rm(f):
437             if os.path.exists(f):
438                 os.remove(f)
439         def rm_ref(ref):
440             if git.ref_exists(ref):
441                 git.delete_ref(ref)
442
443         # Update 0 -> 1.
444         if get_format_version() == 0:
445             mkdir(os.path.join(branch_dir, 'trash'))
446             patch_dir = os.path.join(branch_dir, 'patches')
447             mkdir(patch_dir)
448             refs_base = 'refs/patches/%s' % self.get_name()
449             for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
450                           + file(os.path.join(branch_dir, 'applied')).readlines()):
451                 patch = patch.strip()
452                 os.rename(os.path.join(branch_dir, patch),
453                           os.path.join(patch_dir, patch))
454                 Patch(patch, patch_dir, refs_base).update_top_ref()
455             set_format_version(1)
456
457         # Update 1 -> 2.
458         if get_format_version() == 1:
459             desc_file = os.path.join(branch_dir, 'description')
460             if os.path.isfile(desc_file):
461                 desc = read_string(desc_file)
462                 if desc:
463                     config.set('branch.%s.description' % self.get_name(), desc)
464                 rm(desc_file)
465             rm(os.path.join(branch_dir, 'current'))
466             rm_ref('refs/bases/%s' % self.get_name())
467             set_format_version(2)
468
469         # Make sure we're at the latest version.
470         if not get_format_version() in [None, FORMAT_VERSION]:
471             raise StackException('Branch %s is at format version %d, expected %d'
472                                  % (self.get_name(), get_format_version(), FORMAT_VERSION))
473
474     def __patch_name_valid(self, name):
475         """Raise an exception if the patch name is not valid.
476         """
477         if not name or re.search('[^\w.-]', name):
478             raise StackException, 'Invalid patch name: "%s"' % name
479
480     def get_patch(self, name):
481         """Return a Patch object for the given name
482         """
483         return Patch(name, self.__patch_dir, self.__refs_base)
484
485     def get_current_patch(self):
486         """Return a Patch object representing the topmost patch, or
487         None if there is no such patch."""
488         crt = self.get_current()
489         if not crt:
490             return None
491         return self.get_patch(crt)
492
493     def get_current(self):
494         """Return the name of the topmost patch, or None if there is
495         no such patch."""
496         try:
497             applied = self.get_applied()
498         except StackException:
499             # No "applied" file: branch is not initialized.
500             return None
501         try:
502             return applied[-1]
503         except IndexError:
504             # No patches applied.
505             return None
506
507     def get_applied(self):
508         if not os.path.isfile(self.__applied_file):
509             raise StackException, 'Branch "%s" not initialised' % self.get_name()
510         return read_strings(self.__applied_file)
511
512     def set_applied(self, applied):
513         write_strings(self.__applied_file, applied)
514
515     def get_unapplied(self):
516         if not os.path.isfile(self.__unapplied_file):
517             raise StackException, 'Branch "%s" not initialised' % self.get_name()
518         return read_strings(self.__unapplied_file)
519
520     def set_unapplied(self, unapplied):
521         write_strings(self.__unapplied_file, unapplied)
522
523     def get_hidden(self):
524         if not os.path.isfile(self.__hidden_file):
525             return []
526         return read_strings(self.__hidden_file)
527
528     def get_base(self):
529         # Return the parent of the bottommost patch, if there is one.
530         if os.path.isfile(self.__applied_file):
531             bottommost = file(self.__applied_file).readline().strip()
532             if bottommost:
533                 return self.get_patch(bottommost).get_bottom()
534         # No bottommost patch, so just return HEAD
535         return git.get_head()
536
537     def get_parent_remote(self):
538         value = config.get('branch.%s.remote' % self.get_name())
539         if value:
540             return value
541         elif 'origin' in git.remotes_list():
542             out.note(('No parent remote declared for stack "%s",'
543                       ' defaulting to "origin".' % self.get_name()),
544                      ('Consider setting "branch.%s.remote" and'
545                       ' "branch.%s.merge" with "git config".'
546                       % (self.get_name(), self.get_name())))
547             return 'origin'
548         else:
549             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
550
551     def __set_parent_remote(self, remote):
552         value = config.set('branch.%s.remote' % self.get_name(), remote)
553
554     def get_parent_branch(self):
555         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
556         if value:
557             return value
558         elif git.rev_parse('heads/origin'):
559             out.note(('No parent branch declared for stack "%s",'
560                       ' defaulting to "heads/origin".' % self.get_name()),
561                      ('Consider setting "branch.%s.stgit.parentbranch"'
562                       ' with "git config".' % self.get_name()))
563             return 'heads/origin'
564         else:
565             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
566
567     def __set_parent_branch(self, name):
568         if config.get('branch.%s.remote' % self.get_name()):
569             # Never set merge if remote is not set to avoid
570             # possibly-erroneous lookups into 'origin'
571             config.set('branch.%s.merge' % self.get_name(), name)
572         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
573
574     def set_parent(self, remote, localbranch):
575         if localbranch:
576             if remote:
577                 self.__set_parent_remote(remote)
578             self.__set_parent_branch(localbranch)
579         # We'll enforce this later
580 #         else:
581 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
582
583     def __patch_is_current(self, patch):
584         return patch.get_name() == self.get_current()
585
586     def patch_applied(self, name):
587         """Return true if the patch exists in the applied list
588         """
589         return name in self.get_applied()
590
591     def patch_unapplied(self, name):
592         """Return true if the patch exists in the unapplied list
593         """
594         return name in self.get_unapplied()
595
596     def patch_hidden(self, name):
597         """Return true if the patch is hidden.
598         """
599         return name in self.get_hidden()
600
601     def patch_exists(self, name):
602         """Return true if there is a patch with the given name, false
603         otherwise."""
604         return self.patch_applied(name) or self.patch_unapplied(name) \
605                or self.patch_hidden(name)
606
607     def init(self, create_at=False, parent_remote=None, parent_branch=None):
608         """Initialises the stgit series
609         """
610         if self.is_initialised():
611             raise StackException, '%s already initialized' % self.get_name()
612         for d in [self._dir()]:
613             if os.path.exists(d):
614                 raise StackException, '%s already exists' % d
615
616         if (create_at!=False):
617             git.create_branch(self.get_name(), create_at)
618
619         os.makedirs(self.__patch_dir)
620
621         self.set_parent(parent_remote, parent_branch)
622
623         self.create_empty_field('applied')
624         self.create_empty_field('unapplied')
625         self._set_field('orig-base', git.get_head())
626
627         config.set(self.format_version_key(), str(FORMAT_VERSION))
628
629     def rename(self, to_name):
630         """Renames a series
631         """
632         to_stack = Series(to_name)
633
634         if to_stack.is_initialised():
635             raise StackException, '"%s" already exists' % to_stack.get_name()
636
637         patches = self.get_applied() + self.get_unapplied()
638
639         git.rename_branch(self.get_name(), to_name)
640
641         for patch in patches:
642             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
643                            'refs/patches/%s/%s' % (to_name, patch))
644             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
645                            'refs/patches/%s/%s.log' % (to_name, patch))
646         if os.path.isdir(self._dir()):
647             rename(os.path.join(self._basedir(), 'patches'),
648                    self.get_name(), to_stack.get_name())
649
650         # Rename the config section
651         for k in ['branch.%s', 'branch.%s.stgit']:
652             config.rename_section(k % self.get_name(), k % to_name)
653
654         self.__init__(to_name)
655
656     def clone(self, target_series):
657         """Clones a series
658         """
659         try:
660             # allow cloning of branches not under StGIT control
661             base = self.get_base()
662         except:
663             base = git.get_head()
664         Series(target_series).init(create_at = base)
665         new_series = Series(target_series)
666
667         # generate an artificial description file
668         new_series.set_description('clone of "%s"' % self.get_name())
669
670         # clone self's entire series as unapplied patches
671         try:
672             # allow cloning of branches not under StGIT control
673             applied = self.get_applied()
674             unapplied = self.get_unapplied()
675             patches = applied + unapplied
676             patches.reverse()
677         except:
678             patches = applied = unapplied = []
679         for p in patches:
680             patch = self.get_patch(p)
681             newpatch = new_series.new_patch(p, message = patch.get_description(),
682                                             can_edit = False, unapplied = True,
683                                             bottom = patch.get_bottom(),
684                                             top = patch.get_top(),
685                                             author_name = patch.get_authname(),
686                                             author_email = patch.get_authemail(),
687                                             author_date = patch.get_authdate())
688             if patch.get_log():
689                 out.info('Setting log to %s' %  patch.get_log())
690                 newpatch.set_log(patch.get_log())
691             else:
692                 out.info('No log for %s' % p)
693
694         # fast forward the cloned series to self's top
695         new_series.forward_patches(applied)
696
697         # Clone parent informations
698         value = config.get('branch.%s.remote' % self.get_name())
699         if value:
700             config.set('branch.%s.remote' % target_series, value)
701
702         value = config.get('branch.%s.merge' % self.get_name())
703         if value:
704             config.set('branch.%s.merge' % target_series, value)
705
706         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
707         if value:
708             config.set('branch.%s.stgit.parentbranch' % target_series, value)
709
710     def delete(self, force = False):
711         """Deletes an stgit series
712         """
713         if self.is_initialised():
714             patches = self.get_unapplied() + self.get_applied()
715             if not force and patches:
716                 raise StackException, \
717                       'Cannot delete: the series still contains patches'
718             for p in patches:
719                 self.get_patch(p).delete()
720
721             # remove the trash directory if any
722             if os.path.exists(self.__trash_dir):
723                 for fname in os.listdir(self.__trash_dir):
724                     os.remove(os.path.join(self.__trash_dir, fname))
725                 os.rmdir(self.__trash_dir)
726
727             # FIXME: find a way to get rid of those manual removals
728             # (move functionality to StgitObject ?)
729             if os.path.exists(self.__applied_file):
730                 os.remove(self.__applied_file)
731             if os.path.exists(self.__unapplied_file):
732                 os.remove(self.__unapplied_file)
733             if os.path.exists(self.__hidden_file):
734                 os.remove(self.__hidden_file)
735             if os.path.exists(self._dir()+'/orig-base'):
736                 os.remove(self._dir()+'/orig-base')
737
738             if not os.listdir(self.__patch_dir):
739                 os.rmdir(self.__patch_dir)
740             else:
741                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
742
743             try:
744                 os.removedirs(self._dir())
745             except OSError:
746                 raise StackException('Series directory %s is not empty'
747                                      % self._dir())
748
749             try:
750                 git.delete_branch(self.get_name())
751             except GitException:
752                 out.warn('Could not delete branch "%s"' % self.get_name())
753
754         config.remove_section('branch.%s' % self.get_name())
755         config.remove_section('branch.%s.stgit' % self.get_name())
756
757     def refresh_patch(self, files = None, message = None, edit = False,
758                       show_patch = False,
759                       cache_update = True,
760                       author_name = None, author_email = None,
761                       author_date = None,
762                       committer_name = None, committer_email = None,
763                       backup = False, sign_str = None, log = 'refresh',
764                       notes = None, bottom = None):
765         """Generates a new commit for the topmost patch
766         """
767         patch = self.get_current_patch()
768         if not patch:
769             raise StackException, 'No patches applied'
770
771         descr = patch.get_description()
772         if not (message or descr):
773             edit = True
774             descr = ''
775         elif message:
776             descr = message
777
778         # TODO: move this out of the stgit.stack module, it is really
779         # for higher level commands to handle the user interaction
780         if not message and edit:
781             descr = edit_file(self, descr.rstrip(), \
782                               'Please edit the description for patch "%s" ' \
783                               'above.' % patch.get_name(), show_patch)
784
785         if not author_name:
786             author_name = patch.get_authname()
787         if not author_email:
788             author_email = patch.get_authemail()
789         if not author_date:
790             author_date = patch.get_authdate()
791         if not committer_name:
792             committer_name = patch.get_commname()
793         if not committer_email:
794             committer_email = patch.get_commemail()
795
796         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
797
798         if not bottom:
799             bottom = patch.get_bottom()
800
801         commit_id = git.commit(files = files,
802                                message = descr, parents = [bottom],
803                                cache_update = cache_update,
804                                allowempty = True,
805                                author_name = author_name,
806                                author_email = author_email,
807                                author_date = author_date,
808                                committer_name = committer_name,
809                                committer_email = committer_email)
810
811         patch.set_bottom(bottom, backup = backup)
812         patch.set_top(commit_id, backup = backup)
813         patch.set_description(descr)
814         patch.set_authname(author_name)
815         patch.set_authemail(author_email)
816         patch.set_authdate(author_date)
817         patch.set_commname(committer_name)
818         patch.set_commemail(committer_email)
819
820         if log:
821             self.log_patch(patch, log, notes)
822
823         return commit_id
824
825     def undo_refresh(self):
826         """Undo the patch boundaries changes caused by 'refresh'
827         """
828         name = self.get_current()
829         assert(name)
830
831         patch = self.get_patch(name)
832         old_bottom = patch.get_old_bottom()
833         old_top = patch.get_old_top()
834
835         # the bottom of the patch is not changed by refresh. If the
836         # old_bottom is different, there wasn't any previous 'refresh'
837         # command (probably only a 'push')
838         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
839             raise StackException, 'No undo information available'
840
841         git.reset(tree_id = old_top, check_out = False)
842         if patch.restore_old_boundaries():
843             self.log_patch(patch, 'undo')
844
845     def new_patch(self, name, message = None, can_edit = True,
846                   unapplied = False, show_patch = False,
847                   top = None, bottom = None, commit = True,
848                   author_name = None, author_email = None, author_date = None,
849                   committer_name = None, committer_email = None,
850                   before_existing = False, sign_str = None):
851         """Creates a new patch, either pointing to an existing commit object,
852         or by creating a new commit object.
853         """
854
855         assert commit or (top and bottom)
856         assert not before_existing or (top and bottom)
857         assert not (commit and before_existing)
858         assert (top and bottom) or (not top and not bottom)
859         assert not top or (bottom == git.get_commit(top).get_parent())
860
861         if name != None:
862             self.__patch_name_valid(name)
863             if self.patch_exists(name):
864                 raise StackException, 'Patch "%s" already exists' % name
865
866         # TODO: move this out of the stgit.stack module, it is really
867         # for higher level commands to handle the user interaction
868         def sign(msg):
869             return add_sign_line(msg, sign_str,
870                                  committer_name or git.committer().name,
871                                  committer_email or git.committer().email)
872         if not message and can_edit:
873             descr = edit_file(
874                 self, sign(''),
875                 'Please enter the description for the patch above.',
876                 show_patch)
877         else:
878             descr = sign(message)
879
880         head = git.get_head()
881
882         if name == None:
883             name = make_patch_name(descr, self.patch_exists)
884
885         patch = self.get_patch(name)
886         patch.create()
887
888         patch.set_description(descr)
889         patch.set_authname(author_name)
890         patch.set_authemail(author_email)
891         patch.set_authdate(author_date)
892         patch.set_commname(committer_name)
893         patch.set_commemail(committer_email)
894
895         if before_existing:
896             insert_string(self.__applied_file, patch.get_name())
897         elif unapplied:
898             patches = [patch.get_name()] + self.get_unapplied()
899             write_strings(self.__unapplied_file, patches)
900             set_head = False
901         else:
902             append_string(self.__applied_file, patch.get_name())
903             set_head = True
904
905         if commit:
906             if top:
907                 top_commit = git.get_commit(top)
908             else:
909                 bottom = head
910                 top_commit = git.get_commit(head)
911
912             # create a commit for the patch (may be empty if top == bottom);
913             # only commit on top of the current branch
914             assert(unapplied or bottom == head)
915             commit_id = git.commit(message = descr, parents = [bottom],
916                                    cache_update = False,
917                                    tree_id = top_commit.get_tree(),
918                                    allowempty = True, set_head = set_head,
919                                    author_name = author_name,
920                                    author_email = author_email,
921                                    author_date = author_date,
922                                    committer_name = committer_name,
923                                    committer_email = committer_email)
924             # set the patch top to the new commit
925             patch.set_bottom(bottom)
926             patch.set_top(commit_id)
927         else:
928             assert top != bottom
929             patch.set_bottom(bottom)
930             patch.set_top(top)
931
932         self.log_patch(patch, 'new')
933
934         return patch
935
936     def delete_patch(self, name):
937         """Deletes a patch
938         """
939         self.__patch_name_valid(name)
940         patch = self.get_patch(name)
941
942         if self.__patch_is_current(patch):
943             self.pop_patch(name)
944         elif self.patch_applied(name):
945             raise StackException, 'Cannot remove an applied patch, "%s", ' \
946                   'which is not current' % name
947         elif not name in self.get_unapplied():
948             raise StackException, 'Unknown patch "%s"' % name
949
950         # save the commit id to a trash file
951         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
952
953         patch.delete()
954
955         unapplied = self.get_unapplied()
956         unapplied.remove(name)
957         write_strings(self.__unapplied_file, unapplied)
958
959     def forward_patches(self, names):
960         """Try to fast-forward an array of patches.
961
962         On return, patches in names[0:returned_value] have been pushed on the
963         stack. Apply the rest with push_patch
964         """
965         unapplied = self.get_unapplied()
966
967         forwarded = 0
968         top = git.get_head()
969
970         for name in names:
971             assert(name in unapplied)
972
973             patch = self.get_patch(name)
974
975             head = top
976             bottom = patch.get_bottom()
977             top = patch.get_top()
978
979             # top != bottom always since we have a commit for each patch
980             if head == bottom:
981                 # reset the backup information. No logging since the
982                 # patch hasn't changed
983                 patch.set_bottom(head, backup = True)
984                 patch.set_top(top, backup = True)
985
986             else:
987                 head_tree = git.get_commit(head).get_tree()
988                 bottom_tree = git.get_commit(bottom).get_tree()
989                 if head_tree == bottom_tree:
990                     # We must just reparent this patch and create a new commit
991                     # for it
992                     descr = patch.get_description()
993                     author_name = patch.get_authname()
994                     author_email = patch.get_authemail()
995                     author_date = patch.get_authdate()
996                     committer_name = patch.get_commname()
997                     committer_email = patch.get_commemail()
998
999                     top_tree = git.get_commit(top).get_tree()
1000
1001                     top = git.commit(message = descr, parents = [head],
1002                                      cache_update = False,
1003                                      tree_id = top_tree,
1004                                      allowempty = True,
1005                                      author_name = author_name,
1006                                      author_email = author_email,
1007                                      author_date = author_date,
1008                                      committer_name = committer_name,
1009                                      committer_email = committer_email)
1010
1011                     patch.set_bottom(head, backup = True)
1012                     patch.set_top(top, backup = True)
1013
1014                     self.log_patch(patch, 'push(f)')
1015                 else:
1016                     top = head
1017                     # stop the fast-forwarding, must do a real merge
1018                     break
1019
1020             forwarded+=1
1021             unapplied.remove(name)
1022
1023         if forwarded == 0:
1024             return 0
1025
1026         git.switch(top)
1027
1028         append_strings(self.__applied_file, names[0:forwarded])
1029         write_strings(self.__unapplied_file, unapplied)
1030
1031         return forwarded
1032
1033     def merged_patches(self, names):
1034         """Test which patches were merged upstream by reverse-applying
1035         them in reverse order. The function returns the list of
1036         patches detected to have been applied. The state of the tree
1037         is restored to the original one
1038         """
1039         patches = [self.get_patch(name) for name in names]
1040         patches.reverse()
1041
1042         merged = []
1043         for p in patches:
1044             if git.apply_diff(p.get_top(), p.get_bottom()):
1045                 merged.append(p.get_name())
1046         merged.reverse()
1047
1048         git.reset()
1049
1050         return merged
1051
1052     def push_empty_patch(self, name):
1053         """Pushes an empty patch on the stack
1054         """
1055         unapplied = self.get_unapplied()
1056         assert(name in unapplied)
1057
1058         # patch = self.get_patch(name)
1059         head = git.get_head()
1060
1061         append_string(self.__applied_file, name)
1062
1063         unapplied.remove(name)
1064         write_strings(self.__unapplied_file, unapplied)
1065
1066         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1067
1068     def push_patch(self, name):
1069         """Pushes a patch on the stack
1070         """
1071         unapplied = self.get_unapplied()
1072         assert(name in unapplied)
1073
1074         patch = self.get_patch(name)
1075
1076         head = git.get_head()
1077         bottom = patch.get_bottom()
1078         top = patch.get_top()
1079         # top != bottom always since we have a commit for each patch
1080
1081         if head == bottom:
1082             # A fast-forward push. Just reset the backup
1083             # information. No need for logging
1084             patch.set_bottom(bottom, backup = True)
1085             patch.set_top(top, backup = True)
1086
1087             git.switch(top)
1088             append_string(self.__applied_file, name)
1089
1090             unapplied.remove(name)
1091             write_strings(self.__unapplied_file, unapplied)
1092             return False
1093
1094         # Need to create a new commit an merge in the old patch
1095         ex = None
1096         modified = False
1097
1098         # Try the fast applying first. If this fails, fall back to the
1099         # three-way merge
1100         if not git.apply_diff(bottom, top):
1101             # if git.apply_diff() fails, the patch requires a diff3
1102             # merge and can be reported as modified
1103             modified = True
1104
1105             # merge can fail but the patch needs to be pushed
1106             try:
1107                 git.merge(bottom, head, top, recursive = True)
1108             except git.GitException, ex:
1109                 out.error('The merge failed during "push".',
1110                           'Use "refresh" after fixing the conflicts or'
1111                           ' revert the operation with "push --undo".')
1112
1113         append_string(self.__applied_file, name)
1114
1115         unapplied.remove(name)
1116         write_strings(self.__unapplied_file, unapplied)
1117
1118         if not ex:
1119             # if the merge was OK and no conflicts, just refresh the patch
1120             # The GIT cache was already updated by the merge operation
1121             if modified:
1122                 log = 'push(m)'
1123             else:
1124                 log = 'push'
1125             self.refresh_patch(bottom = head, cache_update = False, log = log)
1126         else:
1127             # we store the correctly merged files only for
1128             # tracking the conflict history. Note that the
1129             # git.merge() operations should always leave the index
1130             # in a valid state (i.e. only stage 0 files)
1131             self.refresh_patch(bottom = head, cache_update = False,
1132                                log = 'push(c)')
1133             raise StackException, str(ex)
1134
1135         return modified
1136
1137     def undo_push(self):
1138         name = self.get_current()
1139         assert(name)
1140
1141         patch = self.get_patch(name)
1142         old_bottom = patch.get_old_bottom()
1143         old_top = patch.get_old_top()
1144
1145         # the top of the patch is changed by a push operation only
1146         # together with the bottom (otherwise the top was probably
1147         # modified by 'refresh'). If they are both unchanged, there
1148         # was a fast forward
1149         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1150             raise StackException, 'No undo information available'
1151
1152         git.reset()
1153         self.pop_patch(name)
1154         ret = patch.restore_old_boundaries()
1155         if ret:
1156             self.log_patch(patch, 'undo')
1157
1158         return ret
1159
1160     def pop_patch(self, name, keep = False):
1161         """Pops the top patch from the stack
1162         """
1163         applied = self.get_applied()
1164         applied.reverse()
1165         assert(name in applied)
1166
1167         patch = self.get_patch(name)
1168
1169         if git.get_head_file() == self.get_name():
1170             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1171                 raise StackException(
1172                     'Failed to pop patches while preserving the local changes')
1173             git.switch(patch.get_bottom(), keep)
1174         else:
1175             git.set_branch(self.get_name(), patch.get_bottom())
1176
1177         # save the new applied list
1178         idx = applied.index(name) + 1
1179
1180         popped = applied[:idx]
1181         popped.reverse()
1182         unapplied = popped + self.get_unapplied()
1183         write_strings(self.__unapplied_file, unapplied)
1184
1185         del applied[:idx]
1186         applied.reverse()
1187         write_strings(self.__applied_file, applied)
1188
1189     def empty_patch(self, name):
1190         """Returns True if the patch is empty
1191         """
1192         self.__patch_name_valid(name)
1193         patch = self.get_patch(name)
1194         bottom = patch.get_bottom()
1195         top = patch.get_top()
1196
1197         if bottom == top:
1198             return True
1199         elif git.get_commit(top).get_tree() \
1200                  == git.get_commit(bottom).get_tree():
1201             return True
1202
1203         return False
1204
1205     def rename_patch(self, oldname, newname):
1206         self.__patch_name_valid(newname)
1207
1208         applied = self.get_applied()
1209         unapplied = self.get_unapplied()
1210
1211         if oldname == newname:
1212             raise StackException, '"To" name and "from" name are the same'
1213
1214         if newname in applied or newname in unapplied:
1215             raise StackException, 'Patch "%s" already exists' % newname
1216
1217         if oldname in unapplied:
1218             self.get_patch(oldname).rename(newname)
1219             unapplied[unapplied.index(oldname)] = newname
1220             write_strings(self.__unapplied_file, unapplied)
1221         elif oldname in applied:
1222             self.get_patch(oldname).rename(newname)
1223
1224             applied[applied.index(oldname)] = newname
1225             write_strings(self.__applied_file, applied)
1226         else:
1227             raise StackException, 'Unknown patch "%s"' % oldname
1228
1229     def log_patch(self, patch, message, notes = None):
1230         """Generate a log commit for a patch
1231         """
1232         top = git.get_commit(patch.get_top())
1233         old_log = patch.get_log()
1234
1235         if message is None:
1236             # replace the current log entry
1237             if not old_log:
1238                 raise StackException, \
1239                       'No log entry to annotate for patch "%s"' \
1240                       % patch.get_name()
1241             replace = True
1242             log_commit = git.get_commit(old_log)
1243             msg = log_commit.get_log().split('\n')[0]
1244             log_parent = log_commit.get_parent()
1245             if log_parent:
1246                 parents = [log_parent]
1247             else:
1248                 parents = []
1249         else:
1250             # generate a new log entry
1251             replace = False
1252             msg = '%s\t%s' % (message, top.get_id_hash())
1253             if old_log:
1254                 parents = [old_log]
1255             else:
1256                 parents = []
1257
1258         if notes:
1259             msg += '\n\n' + notes
1260
1261         log = git.commit(message = msg, parents = parents,
1262                          cache_update = False, tree_id = top.get_tree(),
1263                          allowempty = True)
1264         patch.set_log(log)
1265
1266     def hide_patch(self, name):
1267         """Add the patch to the hidden list.
1268         """
1269         unapplied = self.get_unapplied()
1270         if name not in unapplied:
1271             # keep the checking order for backward compatibility with
1272             # the old hidden patches functionality
1273             if self.patch_applied(name):
1274                 raise StackException, 'Cannot hide applied patch "%s"' % name
1275             elif self.patch_hidden(name):
1276                 raise StackException, 'Patch "%s" already hidden' % name
1277             else:
1278                 raise StackException, 'Unknown patch "%s"' % name
1279
1280         if not self.patch_hidden(name):
1281             # check needed for backward compatibility with the old
1282             # hidden patches functionality
1283             append_string(self.__hidden_file, name)
1284
1285         unapplied.remove(name)
1286         write_strings(self.__unapplied_file, unapplied)
1287
1288     def unhide_patch(self, name):
1289         """Remove the patch from the hidden list.
1290         """
1291         hidden = self.get_hidden()
1292         if not name in hidden:
1293             if self.patch_applied(name) or self.patch_unapplied(name):
1294                 raise StackException, 'Patch "%s" not hidden' % name
1295             else:
1296                 raise StackException, 'Unknown patch "%s"' % name
1297
1298         hidden.remove(name)
1299         write_strings(self.__hidden_file, hidden)
1300
1301         if not self.patch_applied(name) and not self.patch_unapplied(name):
1302             # check needed for backward compatibility with the old
1303             # hidden patches functionality
1304             append_string(self.__unapplied_file, name)