chiark / gitweb /
37ffb6a6ce6e2e6101b055c83203cb92651a96af
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit import git, basedir, templates
27 from stgit.config import config
28 from shutil import copyfile
29
30
class StackException(Exception):
    """Error raised by the stack layer for invalid or impossible
    operations on a series or patch (for instance an uninitialised
    branch or an invalid patch name)."""
34
class FilterUntil:
    """Stateful line predicate used to clean a commit-message file.

    An instance is called once per line, in file order.  A line is kept
    (True) unless it starts with 'prefix'.  As soon as 'until_test'
    matches a line, that line and every later one are dropped (False),
    whatever they contain.
    """
    def __init__(self):
        # becomes False for good once until_test() has matched a line
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        if self.should_print:
            return not x.startswith(prefix)
        return False
44
45 #
46 # Functions
47 #
# Markers used in the temporary commit-message file: lines starting with
# the comment prefix are removed after editing, and a line consisting of
# the patch prefix starts the trailing diff section, which is removed too.
__comment_prefix, __patch_prefix = 'STG:', 'STG_PATCH:'
50
def __clean_comments(f):
    """Rewrite the open message file f in place without StGIT annotations.

    Lines starting with the comment prefix are dropped.  The patch-prefix
    marker line and everything after it are also dropped.  Finally,
    trailing lines consisting only of a newline are stripped.
    """
    f.seek(0)
    keep = FilterUntil()
    marker = __patch_prefix + '\n'
    is_marker = lambda text: text == marker
    kept = [text for text in f.readlines()
            if keep(text, is_marker, __comment_prefix)]

    # drop the blank lines at the end of the message
    end = len(kept)
    while end > 0 and kept[end - 1] == '\n':
        end -= 1

    f.seek(0)
    f.truncate()
    f.writelines(kept[:end])
69
# TODO: move this out of the stgit.stack module, it is really for
# higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a commit message in the external editor.

    The file '.stgitmsg.txt' is seeded with 'line' or, when that is
    empty, the 'patchdescr.tmpl' template.  'STG:'-prefixed help lines
    (including 'comment') follow.  If show_patch is set, the diff from the
    bottom of the topmost patch comes next; when no patch is applied, the
    diff from the series base is used instead.  After the editor
    returns, the annotations are stripped, the file is removed and its
    remaining text is returned.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    try:
        if line:
            print >> f, line
        elif tmpl:
            print >> f, tmpl,
        else:
            print >> f
        print >> f, __comment_prefix, comment
        print >> f, __comment_prefix, \
              'Lines prefixed with "%s" will be automatically removed.' \
              % __comment_prefix
        print >> f, __comment_prefix, \
              'Trailing empty lines will be automatically removed.'

        if show_patch:
            # With no patch applied get_current() is None and
            # get_patch(None) would crash, so diff against the base.
            crt = series.get_current_patch()
            if crt:
                rev1 = crt.get_bottom()
            else:
                rev1 = series.get_base()
            print >> f, __patch_prefix
            f.write(git.diff(rev1 = rev1))

        # Vim modeline must be near the end.
        print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    finally:
        f.close()

    call_editor(fname)

    f = file(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
112
113 #
114 # Classes
115 #
116
class StgitObject:
    """Base class for objects whose properties ("fields") are stored as
    one file per field inside a directory set with _set_dir().
    """
    def _set_dir(self, dir):
        self.__dir = dir
    def _dir(self):
        return self.__dir

    def __field_path(self, name):
        # every field lives in a file named after it
        return os.path.join(self.__dir, name)

    def create_empty_field(self, name):
        create_empty_file(self.__field_path(name))

    def _get_field(self, name, multiline = False):
        """Return the field's content, or None if the file is missing or
        empty."""
        path = self.__field_path(name)
        if not os.path.isfile(path):
            return None
        value = read_string(path, multiline)
        if value == '':
            return None
        return value

    def _set_field(self, name, value, multiline = False):
        """Store value in the field; an empty/None value removes the
        field file instead (if it exists)."""
        path = self.__field_path(name)
        if not value or value == '':
            if os.path.isfile(path):
                os.remove(path)
        else:
            write_string(path, value, multiline)
145
146     
class Patch(StgitObject):
    """A single StGIT patch.

    Its metadata (bottom/top commit ids, description, author and
    committer data, log) is stored as fields in <series_dir>/<name>.
    The top and log commits are also recorded in the git refs
    <refs_base>/<name> and <refs_base>/<name>.log.
    """
    def __init_refs(self):
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty 'bottom' and 'top' fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory, its top ref and its log ref (if any)."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch: its refs (log ref only if present) and its
        metadata directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the stored 'top' commit, if there is one."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom commit; with backup, first save the current
        value as 'bottom.old'."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top commit and the top ref; with backup, first save
        the current value as 'top.old'."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore 'bottom'/'top' (and the top ref) from the '.old'
        backups.  Return True on success, or False, changing nothing, if
        either backup is missing."""
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def __tz_offset(self, tz):
        """Return the offset in seconds of a '+hhmm'/'-hhmm' zone string,
        or 0 if tz does not have that exact shape."""
        digits = tz[1:]
        if len(digits) != 4 or not digits.isdigit():
            return 0
        seconds = int(digits[:2]) * 3600 + int(digits[2:]) * 60
        if tz[0] == '-':
            return -seconds
        return seconds

    def get_authdate(self):
        """Return the author date, or the (empty/None) field if unset.

        A raw "<seconds> <+|-hhmm>" value is converted to an RFC 2822
        date whose clock time is expressed in the recorded zone.  Other
        values are returned unchanged.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        if re.match(r'[0-9]+\s+[+-][0-9]+', date):
            # Unix time (seconds) + time zone
            secs, tz = date.split()[:2]
            # formatdate() prints the UTC clock time; shift it by the zone
            # offset so the printed time matches the zone appended below
            # (otherwise the date is off by that offset).
            local_secs = int(secs) + self.__tz_offset(tz)
            date = formatdate(local_secs)[:-5] + tz

        return date

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Store the log commit and update the log ref.  'backup' is
        accepted for symmetry with the other setters but is unused."""
        self._set_field('log', value)
        self.__update_log_ref(value)
289
# The current StGIT metadata format version.  Existing branches are
# upgraded to it by Series.update_to_current_format_version(), and
# Series.init() records it under the branch's format-version config key.
FORMAT_VERSION = 2
292
293 class PatchSet(StgitObject):
294     def __init__(self, name = None):
295         try:
296             if name:
297                 self.set_name (name)
298             else:
299                 self.set_name (git.get_head_file())
300             self.__base_dir = basedir.get()
301         except git.GitException, ex:
302             raise StackException, 'GIT tree not initialised: %s' % ex
303
304         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
305
306     def get_name(self):
307         return self.__name
308     def set_name(self, name):
309         self.__name = name
310
311     def _basedir(self):
312         return self.__base_dir
313
314     def get_head(self):
315         """Return the head of the branch
316         """
317         crt = self.get_current_patch()
318         if crt:
319             return crt.get_top()
320         else:
321             return self.get_base()
322
323     def get_protected(self):
324         return os.path.isfile(os.path.join(self._dir(), 'protected'))
325
326     def protect(self):
327         protect_file = os.path.join(self._dir(), 'protected')
328         if not os.path.isfile(protect_file):
329             create_empty_file(protect_file)
330
331     def unprotect(self):
332         protect_file = os.path.join(self._dir(), 'protected')
333         if os.path.isfile(protect_file):
334             os.remove(protect_file)
335
336     def __branch_descr(self):
337         return 'branch.%s.description' % self.get_name()
338
339     def get_description(self):
340         return config.get(self.__branch_descr()) or ''
341
342     def set_description(self, line):
343         if line:
344             config.set(self.__branch_descr(), line)
345         else:
346             config.unset(self.__branch_descr())
347
348     def head_top_equal(self):
349         """Return true if the head and the top are the same
350         """
351         crt = self.get_current_patch()
352         if not crt:
353             # we don't care, no patches applied
354             return True
355         return git.get_head() == crt.get_top()
356
357     def is_initialised(self):
358         """Checks if series is already initialised
359         """
360         return bool(config.get(self.format_version_key()))
361
362
def shortlog(patches):
    """Return the 'git-shortlog' summary of the commits in each patch's
    bottom..top range, fed to it in the order the patches are given."""
    chunks = []
    for p in patches:
        rev_range = [p.get_top(), '^%s' % p.get_bottom()]
        chunks.append(Run('git-log', '--pretty=short', *rev_range).raw_output())
    return Run('git-shortlog').raw_input(''.join(chunks)).raw_output()
368
369 class Series(PatchSet):
370     """Class including the operations on series
371     """
372     def __init__(self, name = None):
373         """Takes a series name as the parameter.
374         """
375         PatchSet.__init__(self, name)
376
377         # Update the branch to the latest format version if it is
378         # initialized, but don't touch it if it isn't.
379         self.update_to_current_format_version()
380
381         self.__refs_base = 'refs/patches/%s' % self.get_name()
382
383         self.__applied_file = os.path.join(self._dir(), 'applied')
384         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
385         self.__hidden_file = os.path.join(self._dir(), 'hidden')
386
387         # where this series keeps its patches
388         self.__patch_dir = os.path.join(self._dir(), 'patches')
389
390         # trash directory
391         self.__trash_dir = os.path.join(self._dir(), 'trash')
392
393     def format_version_key(self):
394         return 'branch.%s.stgit.stackformatversion' % self.get_name()
395
396     def update_to_current_format_version(self):
397         """Update a potentially older StGIT directory structure to the
398         latest version. Note: This function should depend as little as
399         possible on external functions that may change during a format
400         version bump, since it must remain able to process older formats."""
401
402         branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
403         def get_format_version():
404             """Return the integer format version number, or None if the
405             branch doesn't have any StGIT metadata at all, of any version."""
406             fv = config.get(self.format_version_key())
407             ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
408             if fv:
409                 # Great, there's an explicitly recorded format version
410                 # number, which means that the branch is initialized and
411                 # of that exact version.
412                 return int(fv)
413             elif ofv:
414                 # Old name for the version info, upgrade it
415                 config.set(self.format_version_key(), ofv)
416                 config.unset('branch.%s.stgitformatversion' % self.get_name())
417                 return int(ofv)
418             elif os.path.isdir(os.path.join(branch_dir, 'patches')):
419                 # There's a .git/patches/<branch>/patches dirctory, which
420                 # means this is an initialized version 1 branch.
421                 return 1
422             elif os.path.isdir(branch_dir):
423                 # There's a .git/patches/<branch> directory, which means
424                 # this is an initialized version 0 branch.
425                 return 0
426             else:
427                 # The branch doesn't seem to be initialized at all.
428                 return None
429         def set_format_version(v):
430             out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
431             config.set(self.format_version_key(), '%d' % v)
432         def mkdir(d):
433             if not os.path.isdir(d):
434                 os.makedirs(d)
435         def rm(f):
436             if os.path.exists(f):
437                 os.remove(f)
438         def rm_ref(ref):
439             if git.ref_exists(ref):
440                 git.delete_ref(ref)
441
442         # Update 0 -> 1.
443         if get_format_version() == 0:
444             mkdir(os.path.join(branch_dir, 'trash'))
445             patch_dir = os.path.join(branch_dir, 'patches')
446             mkdir(patch_dir)
447             refs_base = 'refs/patches/%s' % self.get_name()
448             for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
449                           + file(os.path.join(branch_dir, 'applied')).readlines()):
450                 patch = patch.strip()
451                 os.rename(os.path.join(branch_dir, patch),
452                           os.path.join(patch_dir, patch))
453                 Patch(patch, patch_dir, refs_base).update_top_ref()
454             set_format_version(1)
455
456         # Update 1 -> 2.
457         if get_format_version() == 1:
458             desc_file = os.path.join(branch_dir, 'description')
459             if os.path.isfile(desc_file):
460                 desc = read_string(desc_file)
461                 if desc:
462                     config.set('branch.%s.description' % self.get_name(), desc)
463                 rm(desc_file)
464             rm(os.path.join(branch_dir, 'current'))
465             rm_ref('refs/bases/%s' % self.get_name())
466             set_format_version(2)
467
468         # Make sure we're at the latest version.
469         if not get_format_version() in [None, FORMAT_VERSION]:
470             raise StackException('Branch %s is at format version %d, expected %d'
471                                  % (self.get_name(), get_format_version(), FORMAT_VERSION))
472
473     def __patch_name_valid(self, name):
474         """Raise an exception if the patch name is not valid.
475         """
476         if not name or re.search('[^\w.-]', name):
477             raise StackException, 'Invalid patch name: "%s"' % name
478
479     def get_patch(self, name):
480         """Return a Patch object for the given name
481         """
482         return Patch(name, self.__patch_dir, self.__refs_base)
483
484     def get_current_patch(self):
485         """Return a Patch object representing the topmost patch, or
486         None if there is no such patch."""
487         crt = self.get_current()
488         if not crt:
489             return None
490         return self.get_patch(crt)
491
492     def get_current(self):
493         """Return the name of the topmost patch, or None if there is
494         no such patch."""
495         try:
496             applied = self.get_applied()
497         except StackException:
498             # No "applied" file: branch is not initialized.
499             return None
500         try:
501             return applied[-1]
502         except IndexError:
503             # No patches applied.
504             return None
505
506     def get_applied(self):
507         if not os.path.isfile(self.__applied_file):
508             raise StackException, 'Branch "%s" not initialised' % self.get_name()
509         return read_strings(self.__applied_file)
510
511     def get_unapplied(self):
512         if not os.path.isfile(self.__unapplied_file):
513             raise StackException, 'Branch "%s" not initialised' % self.get_name()
514         return read_strings(self.__unapplied_file)
515
516     def get_hidden(self):
517         if not os.path.isfile(self.__hidden_file):
518             return []
519         return read_strings(self.__hidden_file)
520
521     def get_base(self):
522         # Return the parent of the bottommost patch, if there is one.
523         if os.path.isfile(self.__applied_file):
524             bottommost = file(self.__applied_file).readline().strip()
525             if bottommost:
526                 return self.get_patch(bottommost).get_bottom()
527         # No bottommost patch, so just return HEAD
528         return git.get_head()
529
530     def get_parent_remote(self):
531         value = config.get('branch.%s.remote' % self.get_name())
532         if value:
533             return value
534         elif 'origin' in git.remotes_list():
535             out.note(('No parent remote declared for stack "%s",'
536                       ' defaulting to "origin".' % self.get_name()),
537                      ('Consider setting "branch.%s.remote" and'
538                       ' "branch.%s.merge" with "git config".'
539                       % (self.get_name(), self.get_name())))
540             return 'origin'
541         else:
542             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
543
544     def __set_parent_remote(self, remote):
545         value = config.set('branch.%s.remote' % self.get_name(), remote)
546
547     def get_parent_branch(self):
548         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
549         if value:
550             return value
551         elif git.rev_parse('heads/origin'):
552             out.note(('No parent branch declared for stack "%s",'
553                       ' defaulting to "heads/origin".' % self.get_name()),
554                      ('Consider setting "branch.%s.stgit.parentbranch"'
555                       ' with "git config".' % self.get_name()))
556             return 'heads/origin'
557         else:
558             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
559
560     def __set_parent_branch(self, name):
561         if config.get('branch.%s.remote' % self.get_name()):
562             # Never set merge if remote is not set to avoid
563             # possibly-erroneous lookups into 'origin'
564             config.set('branch.%s.merge' % self.get_name(), name)
565         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
566
567     def set_parent(self, remote, localbranch):
568         if localbranch:
569             if remote:
570                 self.__set_parent_remote(remote)
571             self.__set_parent_branch(localbranch)
572         # We'll enforce this later
573 #         else:
574 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
575
576     def __patch_is_current(self, patch):
577         return patch.get_name() == self.get_current()
578
579     def patch_applied(self, name):
580         """Return true if the patch exists in the applied list
581         """
582         return name in self.get_applied()
583
584     def patch_unapplied(self, name):
585         """Return true if the patch exists in the unapplied list
586         """
587         return name in self.get_unapplied()
588
589     def patch_hidden(self, name):
590         """Return true if the patch is hidden.
591         """
592         return name in self.get_hidden()
593
594     def patch_exists(self, name):
595         """Return true if there is a patch with the given name, false
596         otherwise."""
597         return self.patch_applied(name) or self.patch_unapplied(name) \
598                or self.patch_hidden(name)
599
600     def init(self, create_at=False, parent_remote=None, parent_branch=None):
601         """Initialises the stgit series
602         """
603         if self.is_initialised():
604             raise StackException, '%s already initialized' % self.get_name()
605         for d in [self._dir()]:
606             if os.path.exists(d):
607                 raise StackException, '%s already exists' % d
608
609         if (create_at!=False):
610             git.create_branch(self.get_name(), create_at)
611
612         os.makedirs(self.__patch_dir)
613
614         self.set_parent(parent_remote, parent_branch)
615
616         self.create_empty_field('applied')
617         self.create_empty_field('unapplied')
618         self._set_field('orig-base', git.get_head())
619
620         config.set(self.format_version_key(), str(FORMAT_VERSION))
621
622     def rename(self, to_name):
623         """Renames a series
624         """
625         to_stack = Series(to_name)
626
627         if to_stack.is_initialised():
628             raise StackException, '"%s" already exists' % to_stack.get_name()
629
630         patches = self.get_applied() + self.get_unapplied()
631
632         git.rename_branch(self.get_name(), to_name)
633
634         for patch in patches:
635             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
636                            'refs/patches/%s/%s' % (to_name, patch))
637             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
638                            'refs/patches/%s/%s.log' % (to_name, patch))
639         if os.path.isdir(self._dir()):
640             rename(os.path.join(self._basedir(), 'patches'),
641                    self.get_name(), to_stack.get_name())
642
643         # Rename the config section
644         for k in ['branch.%s', 'branch.%s.stgit']:
645             config.rename_section(k % self.get_name(), k % to_name)
646
647         self.__init__(to_name)
648
649     def clone(self, target_series):
650         """Clones a series
651         """
652         try:
653             # allow cloning of branches not under StGIT control
654             base = self.get_base()
655         except:
656             base = git.get_head()
657         Series(target_series).init(create_at = base)
658         new_series = Series(target_series)
659
660         # generate an artificial description file
661         new_series.set_description('clone of "%s"' % self.get_name())
662
663         # clone self's entire series as unapplied patches
664         try:
665             # allow cloning of branches not under StGIT control
666             applied = self.get_applied()
667             unapplied = self.get_unapplied()
668             patches = applied + unapplied
669             patches.reverse()
670         except:
671             patches = applied = unapplied = []
672         for p in patches:
673             patch = self.get_patch(p)
674             newpatch = new_series.new_patch(p, message = patch.get_description(),
675                                             can_edit = False, unapplied = True,
676                                             bottom = patch.get_bottom(),
677                                             top = patch.get_top(),
678                                             author_name = patch.get_authname(),
679                                             author_email = patch.get_authemail(),
680                                             author_date = patch.get_authdate())
681             if patch.get_log():
682                 out.info('Setting log to %s' %  patch.get_log())
683                 newpatch.set_log(patch.get_log())
684             else:
685                 out.info('No log for %s' % p)
686
687         # fast forward the cloned series to self's top
688         new_series.forward_patches(applied)
689
690         # Clone parent informations
691         value = config.get('branch.%s.remote' % self.get_name())
692         if value:
693             config.set('branch.%s.remote' % target_series, value)
694
695         value = config.get('branch.%s.merge' % self.get_name())
696         if value:
697             config.set('branch.%s.merge' % target_series, value)
698
699         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
700         if value:
701             config.set('branch.%s.stgit.parentbranch' % target_series, value)
702
703     def delete(self, force = False):
704         """Deletes an stgit series
705         """
706         if self.is_initialised():
707             patches = self.get_unapplied() + self.get_applied()
708             if not force and patches:
709                 raise StackException, \
710                       'Cannot delete: the series still contains patches'
711             for p in patches:
712                 self.get_patch(p).delete()
713
714             # remove the trash directory if any
715             if os.path.exists(self.__trash_dir):
716                 for fname in os.listdir(self.__trash_dir):
717                     os.remove(os.path.join(self.__trash_dir, fname))
718                 os.rmdir(self.__trash_dir)
719
720             # FIXME: find a way to get rid of those manual removals
721             # (move functionality to StgitObject ?)
722             if os.path.exists(self.__applied_file):
723                 os.remove(self.__applied_file)
724             if os.path.exists(self.__unapplied_file):
725                 os.remove(self.__unapplied_file)
726             if os.path.exists(self.__hidden_file):
727                 os.remove(self.__hidden_file)
728             if os.path.exists(self._dir()+'/orig-base'):
729                 os.remove(self._dir()+'/orig-base')
730
731             if not os.listdir(self.__patch_dir):
732                 os.rmdir(self.__patch_dir)
733             else:
734                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
735
736             try:
737                 os.removedirs(self._dir())
738             except OSError:
739                 raise StackException('Series directory %s is not empty'
740                                      % self._dir())
741
742             try:
743                 git.delete_branch(self.get_name())
744             except GitException:
745                 out.warn('Could not delete branch "%s"' % self.get_name())
746
747         config.remove_section('branch.%s' % self.get_name())
748         config.remove_section('branch.%s.stgit' % self.get_name())
749
750     def refresh_patch(self, files = None, message = None, edit = False,
751                       show_patch = False,
752                       cache_update = True,
753                       author_name = None, author_email = None,
754                       author_date = None,
755                       committer_name = None, committer_email = None,
756                       backup = False, sign_str = None, log = 'refresh',
757                       notes = None):
758         """Generates a new commit for the topmost patch
759         """
760         patch = self.get_current_patch()
761         if not patch:
762             raise StackException, 'No patches applied'
763
764         descr = patch.get_description()
765         if not (message or descr):
766             edit = True
767             descr = ''
768         elif message:
769             descr = message
770
771         # TODO: move this out of the stgit.stack module, it is really
772         # for higher level commands to handle the user interaction
773         if not message and edit:
774             descr = edit_file(self, descr.rstrip(), \
775                               'Please edit the description for patch "%s" ' \
776                               'above.' % patch.get_name(), show_patch)
777
778         if not author_name:
779             author_name = patch.get_authname()
780         if not author_email:
781             author_email = patch.get_authemail()
782         if not author_date:
783             author_date = patch.get_authdate()
784         if not committer_name:
785             committer_name = patch.get_commname()
786         if not committer_email:
787             committer_email = patch.get_commemail()
788
789         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
790
791         bottom = patch.get_bottom()
792
793         commit_id = git.commit(files = files,
794                                message = descr, parents = [bottom],
795                                cache_update = cache_update,
796                                allowempty = True,
797                                author_name = author_name,
798                                author_email = author_email,
799                                author_date = author_date,
800                                committer_name = committer_name,
801                                committer_email = committer_email)
802
803         patch.set_bottom(bottom, backup = backup)
804         patch.set_top(commit_id, backup = backup)
805         patch.set_description(descr)
806         patch.set_authname(author_name)
807         patch.set_authemail(author_email)
808         patch.set_authdate(author_date)
809         patch.set_commname(committer_name)
810         patch.set_commemail(committer_email)
811
812         if log:
813             self.log_patch(patch, log, notes)
814
815         return commit_id
816
817     def undo_refresh(self):
818         """Undo the patch boundaries changes caused by 'refresh'
819         """
820         name = self.get_current()
821         assert(name)
822
823         patch = self.get_patch(name)
824         old_bottom = patch.get_old_bottom()
825         old_top = patch.get_old_top()
826
827         # the bottom of the patch is not changed by refresh. If the
828         # old_bottom is different, there wasn't any previous 'refresh'
829         # command (probably only a 'push')
830         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
831             raise StackException, 'No undo information available'
832
833         git.reset(tree_id = old_top, check_out = False)
834         if patch.restore_old_boundaries():
835             self.log_patch(patch, 'undo')
836
837     def new_patch(self, name, message = None, can_edit = True,
838                   unapplied = False, show_patch = False,
839                   top = None, bottom = None, commit = True,
840                   author_name = None, author_email = None, author_date = None,
841                   committer_name = None, committer_email = None,
842                   before_existing = False):
843         """Creates a new patch
844         """
845
846         if name != None:
847             self.__patch_name_valid(name)
848             if self.patch_exists(name):
849                 raise StackException, 'Patch "%s" already exists' % name
850
851         # TODO: move this out of the stgit.stack module, it is really
852         # for higher level commands to handle the user interaction
853         if not message and can_edit:
854             descr = edit_file(
855                 self, None,
856                 'Please enter the description for the patch above.',
857                 show_patch)
858         else:
859             descr = message
860
861         head = git.get_head()
862
863         if name == None:
864             name = make_patch_name(descr, self.patch_exists)
865
866         patch = self.get_patch(name)
867         patch.create()
868
869         if not bottom:
870             bottom = head
871         if not top:
872             top = head
873
874         patch.set_bottom(bottom)
875         patch.set_top(top)
876         patch.set_description(descr)
877         patch.set_authname(author_name)
878         patch.set_authemail(author_email)
879         patch.set_authdate(author_date)
880         patch.set_commname(committer_name)
881         patch.set_commemail(committer_email)
882
883         if before_existing:
884             insert_string(self.__applied_file, patch.get_name())
885             # no need to commit anything as the object is already
886             # present (mainly used by 'uncommit')
887             commit = False
888         elif unapplied:
889             patches = [patch.get_name()] + self.get_unapplied()
890             write_strings(self.__unapplied_file, patches)
891             set_head = False
892         else:
893             append_string(self.__applied_file, patch.get_name())
894             set_head = True
895
896         if commit:
897             # create a commit for the patch (may be empty if top == bottom);
898             # only commit on top of the current branch
899             assert(unapplied or bottom == head)
900             top_commit = git.get_commit(top)
901             commit_id = git.commit(message = descr, parents = [bottom],
902                                    cache_update = False,
903                                    tree_id = top_commit.get_tree(),
904                                    allowempty = True, set_head = set_head,
905                                    author_name = author_name,
906                                    author_email = author_email,
907                                    author_date = author_date,
908                                    committer_name = committer_name,
909                                    committer_email = committer_email)
910             # set the patch top to the new commit
911             patch.set_top(commit_id)
912
913         self.log_patch(patch, 'new')
914
915         return patch
916
917     def delete_patch(self, name):
918         """Deletes a patch
919         """
920         self.__patch_name_valid(name)
921         patch = self.get_patch(name)
922
923         if self.__patch_is_current(patch):
924             self.pop_patch(name)
925         elif self.patch_applied(name):
926             raise StackException, 'Cannot remove an applied patch, "%s", ' \
927                   'which is not current' % name
928         elif not name in self.get_unapplied():
929             raise StackException, 'Unknown patch "%s"' % name
930
931         # save the commit id to a trash file
932         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
933
934         patch.delete()
935
936         unapplied = self.get_unapplied()
937         unapplied.remove(name)
938         write_strings(self.__unapplied_file, unapplied)
939
    def forward_patches(self, names):
        """Try to fast-forward an array of patches.

        The patches are processed in order and each must be unapplied
        (asserted). A patch whose bottom is the previous top is kept as
        it is (only its backup boundaries are reset). A patch whose bottom
        tree equals the previous top's tree is re-committed on top of it
        with the same description, authorship and top tree, and logged as
        'push(f)'. The first patch that needs a real merge stops the loop.

        On return, patches in names[0:returned_value] have been pushed on
        the stack; git.switch() is called with the top of the last
        forwarded patch. Apply the rest with push_patch.
        """
        unapplied = self.get_unapplied()

        forwarded = 0
        # 'top' holds the top of the last forwarded patch (initially the
        # current HEAD); the next patch has to be based on it
        top = git.get_head()

        for name in names:
            assert(name in unapplied)

            patch = self.get_patch(name)

            head = top
            bottom = patch.get_bottom()
            top = patch.get_top()

            # top != bottom always since we have a commit for each patch
            if head == bottom:
                # reset the backup information. No logging since the
                # patch hasn't changed
                patch.set_bottom(head, backup = True)
                patch.set_top(top, backup = True)

            else:
                head_tree = git.get_commit(head).get_tree()
                bottom_tree = git.get_commit(bottom).get_tree()
                if head_tree == bottom_tree:
                    # We must just reparent this patch and create a new commit
                    # for it
                    descr = patch.get_description()
                    author_name = patch.get_authname()
                    author_email = patch.get_authemail()
                    author_date = patch.get_authdate()
                    committer_name = patch.get_commname()
                    committer_email = patch.get_commemail()

                    top_tree = git.get_commit(top).get_tree()

                    # the new commit becomes the top used by the next
                    # iteration
                    top = git.commit(message = descr, parents = [head],
                                     cache_update = False,
                                     tree_id = top_tree,
                                     allowempty = True,
                                     author_name = author_name,
                                     author_email = author_email,
                                     author_date = author_date,
                                     committer_name = committer_name,
                                     committer_email = committer_email)

                    patch.set_bottom(head, backup = True)
                    patch.set_top(top, backup = True)

                    self.log_patch(patch, 'push(f)')
                else:
                    # go back to the top of the last forwarded patch
                    top = head
                    # stop the fast-forwarding, must do a real merge
                    break

            forwarded+=1
            unapplied.remove(name)

        if forwarded == 0:
            return 0

        # switch to the top of the last forwarded patch
        git.switch(top)

        append_strings(self.__applied_file, names[0:forwarded])
        write_strings(self.__unapplied_file, unapplied)

        return forwarded
1013
1014     def merged_patches(self, names):
1015         """Test which patches were merged upstream by reverse-applying
1016         them in reverse order. The function returns the list of
1017         patches detected to have been applied. The state of the tree
1018         is restored to the original one
1019         """
1020         patches = [self.get_patch(name) for name in names]
1021         patches.reverse()
1022
1023         merged = []
1024         for p in patches:
1025             if git.apply_diff(p.get_top(), p.get_bottom()):
1026                 merged.append(p.get_name())
1027         merged.reverse()
1028
1029         git.reset()
1030
1031         return merged
1032
1033     def push_empty_patch(self, name):
1034         """Pushes an empty patch on the stack
1035         """
1036         unapplied = self.get_unapplied()
1037         assert(name in unapplied)
1038
1039         patch = self.get_patch(name)
1040         head = git.get_head()
1041
1042         # The top is updated by refresh_patch since we need an empty
1043         # commit
1044         patch.set_bottom(head, backup = True)
1045         patch.set_top(head, backup = True)
1046
1047         append_string(self.__applied_file, name)
1048
1049         unapplied.remove(name)
1050         write_strings(self.__unapplied_file, unapplied)
1051
1052         self.refresh_patch(cache_update = False, log = 'push(m)')
1053
    def push_patch(self, name):
        """Pushes a patch on the stack

        The patch must be in the unapplied list (asserted). Returns False
        for a fast-forward push (the patch's bottom is the current head)
        or when git.apply_diff() applies the patch directly. Returns True
        when a three-way merge was needed ("modified").

        If git.merge() raises git.GitException, the patch is still moved
        to the applied list and refreshed (logged as 'push(c)'), and a
        StackException carrying the git error text is raised.
        """
        unapplied = self.get_unapplied()
        assert(name in unapplied)

        patch = self.get_patch(name)

        head = git.get_head()
        bottom = patch.get_bottom()
        top = patch.get_top()
        # top != bottom always since we have a commit for each patch

        if head == bottom:
            # A fast-forward push. Just reset the backup
            # information. No need for logging
            patch.set_bottom(bottom, backup = True)
            patch.set_top(top, backup = True)

            git.switch(top)
            append_string(self.__applied_file, name)

            unapplied.remove(name)
            write_strings(self.__unapplied_file, unapplied)
            return False

        # Need to create a new commit and merge in the old patch
        # 'ex' stays None unless the merge below fails; this relies on the
        # Python 2 'except ..., ex' binding remaining set after the handler
        ex = None
        modified = False

        # new patch needs to be refreshed.
        # The current patch is empty after merge.
        patch.set_bottom(head, backup = True)
        patch.set_top(head, backup = True)

        # Try the fast applying first. If this fails, fall back to the
        # three-way merge
        if not git.apply_diff(bottom, top):
            # if git.apply_diff() fails, the patch requires a diff3
            # merge and can be reported as modified
            modified = True

            # merge can fail but the patch needs to be pushed
            try:
                git.merge(bottom, head, top, recursive = True)
            except git.GitException, ex:
                out.error('The merge failed during "push".',
                          'Use "refresh" after fixing the conflicts or'
                          ' revert the operation with "push --undo".')

        append_string(self.__applied_file, name)

        unapplied.remove(name)
        write_strings(self.__unapplied_file, unapplied)

        if not ex:
            # if the merge was OK and no conflicts, just refresh the patch
            # The GIT cache was already updated by the merge operation
            if modified:
                log = 'push(m)'
            else:
                log = 'push'
            self.refresh_patch(cache_update = False, log = log)
        else:
            # we store the correctly merged files only for
            # tracking the conflict history. Note that the
            # git.merge() operations should always leave the index
            # in a valid state (i.e. only stage 0 files)
            self.refresh_patch(cache_update = False, log = 'push(c)')
            raise StackException, str(ex)

        return modified
1126
1127     def undo_push(self):
1128         name = self.get_current()
1129         assert(name)
1130
1131         patch = self.get_patch(name)
1132         old_bottom = patch.get_old_bottom()
1133         old_top = patch.get_old_top()
1134
1135         # the top of the patch is changed by a push operation only
1136         # together with the bottom (otherwise the top was probably
1137         # modified by 'refresh'). If they are both unchanged, there
1138         # was a fast forward
1139         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1140             raise StackException, 'No undo information available'
1141
1142         git.reset()
1143         self.pop_patch(name)
1144         ret = patch.restore_old_boundaries()
1145         if ret:
1146             self.log_patch(patch, 'undo')
1147
1148         return ret
1149
1150     def pop_patch(self, name, keep = False):
1151         """Pops the top patch from the stack
1152         """
1153         applied = self.get_applied()
1154         applied.reverse()
1155         assert(name in applied)
1156
1157         patch = self.get_patch(name)
1158
1159         if git.get_head_file() == self.get_name():
1160             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1161                 raise StackException(
1162                     'Failed to pop patches while preserving the local changes')
1163             git.switch(patch.get_bottom(), keep)
1164         else:
1165             git.set_branch(self.get_name(), patch.get_bottom())
1166
1167         # save the new applied list
1168         idx = applied.index(name) + 1
1169
1170         popped = applied[:idx]
1171         popped.reverse()
1172         unapplied = popped + self.get_unapplied()
1173         write_strings(self.__unapplied_file, unapplied)
1174
1175         del applied[:idx]
1176         applied.reverse()
1177         write_strings(self.__applied_file, applied)
1178
1179     def empty_patch(self, name):
1180         """Returns True if the patch is empty
1181         """
1182         self.__patch_name_valid(name)
1183         patch = self.get_patch(name)
1184         bottom = patch.get_bottom()
1185         top = patch.get_top()
1186
1187         if bottom == top:
1188             return True
1189         elif git.get_commit(top).get_tree() \
1190                  == git.get_commit(bottom).get_tree():
1191             return True
1192
1193         return False
1194
1195     def rename_patch(self, oldname, newname):
1196         self.__patch_name_valid(newname)
1197
1198         applied = self.get_applied()
1199         unapplied = self.get_unapplied()
1200
1201         if oldname == newname:
1202             raise StackException, '"To" name and "from" name are the same'
1203
1204         if newname in applied or newname in unapplied:
1205             raise StackException, 'Patch "%s" already exists' % newname
1206
1207         if oldname in unapplied:
1208             self.get_patch(oldname).rename(newname)
1209             unapplied[unapplied.index(oldname)] = newname
1210             write_strings(self.__unapplied_file, unapplied)
1211         elif oldname in applied:
1212             self.get_patch(oldname).rename(newname)
1213
1214             applied[applied.index(oldname)] = newname
1215             write_strings(self.__applied_file, applied)
1216         else:
1217             raise StackException, 'Unknown patch "%s"' % oldname
1218
1219     def log_patch(self, patch, message, notes = None):
1220         """Generate a log commit for a patch
1221         """
1222         top = git.get_commit(patch.get_top())
1223         old_log = patch.get_log()
1224
1225         if message is None:
1226             # replace the current log entry
1227             if not old_log:
1228                 raise StackException, \
1229                       'No log entry to annotate for patch "%s"' \
1230                       % patch.get_name()
1231             replace = True
1232             log_commit = git.get_commit(old_log)
1233             msg = log_commit.get_log().split('\n')[0]
1234             log_parent = log_commit.get_parent()
1235             if log_parent:
1236                 parents = [log_parent]
1237             else:
1238                 parents = []
1239         else:
1240             # generate a new log entry
1241             replace = False
1242             msg = '%s\t%s' % (message, top.get_id_hash())
1243             if old_log:
1244                 parents = [old_log]
1245             else:
1246                 parents = []
1247
1248         if notes:
1249             msg += '\n\n' + notes
1250
1251         log = git.commit(message = msg, parents = parents,
1252                          cache_update = False, tree_id = top.get_tree(),
1253                          allowempty = True)
1254         patch.set_log(log)
1255
1256     def hide_patch(self, name):
1257         """Add the patch to the hidden list.
1258         """
1259         unapplied = self.get_unapplied()
1260         if name not in unapplied:
1261             # keep the checking order for backward compatibility with
1262             # the old hidden patches functionality
1263             if self.patch_applied(name):
1264                 raise StackException, 'Cannot hide applied patch "%s"' % name
1265             elif self.patch_hidden(name):
1266                 raise StackException, 'Patch "%s" already hidden' % name
1267             else:
1268                 raise StackException, 'Unknown patch "%s"' % name
1269
1270         if not self.patch_hidden(name):
1271             # check needed for backward compatibility with the old
1272             # hidden patches functionality
1273             append_string(self.__hidden_file, name)
1274
1275         unapplied.remove(name)
1276         write_strings(self.__unapplied_file, unapplied)
1277
1278     def unhide_patch(self, name):
1279         """Remove the patch from the hidden list.
1280         """
1281         hidden = self.get_hidden()
1282         if not name in hidden:
1283             if self.patch_applied(name) or self.patch_unapplied(name):
1284                 raise StackException, 'Patch "%s" not hidden' % name
1285             else:
1286                 raise StackException, 'Unknown patch "%s"' % name
1287
1288         hidden.remove(name)
1289         write_strings(self.__hidden_file, hidden)
1290
1291         if not self.patch_applied(name) and not self.patch_unapplied(name):
1292             # check needed for backward compatibility with the old
1293             # hidden patches functionality
1294             append_string(self.__unapplied_file, name)