chiark / gitweb /
Make Series.refresh_patch automatically save the undo information
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import git, basedir, templates
29 from stgit.config import config
30 from shutil import copyfile
31
32
33 # stack exception class
class StackException(StgException):
    """Error raised when an operation on a patch stack cannot be done."""
36
class FilterUntil:
    """Stateful line filter.

    Each call answers whether the line 'x' should be kept.  Lines that
    start with 'prefix' are rejected, and once until_test(x) has been
    true for some line, that line and every later one are rejected.
    """
    def __init__(self):
        # flipped to False for good the first time until_test() matches
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            self.should_print = False
        # keep the line only while still printing and if it does not
        # begin with the prefix
        return self.should_print and x[:len(prefix)] != prefix
46
47 #
48 # Functions
49 #
# Marker for status/comment lines in the message file written by
# edit_file(); __clean_comments() strips the lines starting with it.
__comment_prefix = 'STG:'
# Marker line written by edit_file() just before the appended diff;
# __clean_comments() drops this line and everything after it.
__patch_prefix = 'STG_PATCH:'
52
def __clean_comments(f):
    """Rewrite the open file 'f' in place without the status lines

    Lines starting with the comment prefix are dropped, as is the
    patch marker line together with everything that follows it.
    Empty lines left at the end are removed as well.
    """
    f.seek(0)

    keep = FilterUntil()
    marker = __patch_prefix + '\n'
    at_patch_marker = lambda text: text == marker

    # drop the status-prefixed lines and the appended patch
    kept = [text for text in f.readlines()
            if keep(text, at_patch_marker, __comment_prefix)]

    # strip the trailing empty lines
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
71
72 # TODO: move this out of the stgit.stack module, it is really for
73 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch description and return the result

    A scratch file named '.stgitmsg.txt' (relative to the current
    directory) is filled with 'line' if given, otherwise with the
    'patchdescr.tmpl' template if there is one, otherwise with an
    empty line.  Explanatory lines carrying the comment prefix
    (including 'comment') are appended and, if show_patch is true, the
    patch marker followed by the diff taken from the bottom of the
    current patch of 'series'.  After the editor returns, the file is
    passed through __clean_comments(), its remaining text is returned
    and the file is removed.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = file(fname, 'w+')
    if line:
        print >> f, line
    elif tmpl:
        # trailing comma: no newline is added after the template text
        print >> f, tmpl,
    else:
        print >> f
    print >> f, __comment_prefix, comment
    print >> f, __comment_prefix, \
          'Lines prefixed with "%s" will be automatically removed.' \
          % __comment_prefix
    print >> f, __comment_prefix, \
          'Trailing empty lines will be automatically removed.'

    if show_patch:
       # everything from this marker line on is dropped again by
       # __clean_comments()
       print >> f, __patch_prefix
       # series.get_patch(series.get_current()).get_top()
       diff_str = git.diff(rev1 = series.get_patch(series.get_current()).get_bottom())
       f.write(diff_str)

    #Vim modeline must be near the end.
    print >> f, __comment_prefix, 'vi: set textwidth=75 filetype=diff nobackup:'
    f.close()

    call_editor(fname)

    # reopen read/write so that __clean_comments() can rewrite the file
    f = file(fname, 'r+')

    __clean_comments(f)
    f.seek(0)
    result = f.read()

    f.close()
    os.remove(fname)

    return result
114
115 #
116 # Classes
117 #
118
class StgitObject:
    """Base for objects whose properties live as files in one directory

    Every field is a file, named after the field, inside the directory
    registered with _set_dir().
    """
    def _set_dir(self, dir):
        self.__dir = dir

    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        """Create the (empty) file backing the field 'name'
        """
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the value of the field 'name', or None if the backing
        file does not exist or holds an empty string
        """
        path = os.path.join(self.__dir, name)
        if not os.path.isfile(path):
            return None
        content = read_string(path, multiline)
        if content == '':
            return None
        return content

    def _set_field(self, name, value, multiline = False):
        """Store 'value' in the field 'name'; an empty or None value
        removes the backing file instead
        """
        path = os.path.join(self.__dir, name)
        if value:
            write_string(path, value, multiline)
            return
        if os.path.isfile(path):
            os.remove(path)
147
148
149 class Patch(StgitObject):
150     """Basic patch implementation
151     """
152     def __init_refs(self):
153         self.__top_ref = self.__refs_base + '/' + self.__name
154         self.__log_ref = self.__top_ref + '.log'
155
156     def __init__(self, name, series_dir, refs_base):
157         self.__series_dir = series_dir
158         self.__name = name
159         self._set_dir(os.path.join(self.__series_dir, self.__name))
160         self.__refs_base = refs_base
161         self.__init_refs()
162
163     def create(self):
164         os.mkdir(self._dir())
165         self.create_empty_field('bottom')
166         self.create_empty_field('top')
167
168     def delete(self, keep_log = False):
169         if os.path.isdir(self._dir()):
170             for f in os.listdir(self._dir()):
171                 os.remove(os.path.join(self._dir(), f))
172             os.rmdir(self._dir())
173         else:
174             out.warn('Patch directory "%s" does not exist' % self._dir())
175         try:
176             # the reference might not exist if the repository was corrupted
177             git.delete_ref(self.__top_ref)
178         except git.GitException, e:
179             out.warn(str(e))
180         if not keep_log and git.ref_exists(self.__log_ref):
181             git.delete_ref(self.__log_ref)
182
183     def get_name(self):
184         return self.__name
185
186     def rename(self, newname):
187         olddir = self._dir()
188         old_top_ref = self.__top_ref
189         old_log_ref = self.__log_ref
190         self.__name = newname
191         self._set_dir(os.path.join(self.__series_dir, self.__name))
192         self.__init_refs()
193
194         git.rename_ref(old_top_ref, self.__top_ref)
195         if git.ref_exists(old_log_ref):
196             git.rename_ref(old_log_ref, self.__log_ref)
197         os.rename(olddir, self._dir())
198
199     def __update_top_ref(self, ref):
200         git.set_ref(self.__top_ref, ref)
201
202     def __update_log_ref(self, ref):
203         git.set_ref(self.__log_ref, ref)
204
205     def update_top_ref(self):
206         top = self.get_top()
207         if top:
208             self.__update_top_ref(top)
209
210     def get_old_bottom(self):
211         return self._get_field('bottom.old')
212
213     def get_bottom(self):
214         return self._get_field('bottom')
215
216     def set_bottom(self, value, backup = False):
217         if backup:
218             curr = self._get_field('bottom')
219             self._set_field('bottom.old', curr)
220         self._set_field('bottom', value)
221
222     def get_old_top(self):
223         return self._get_field('top.old')
224
225     def get_top(self):
226         return self._get_field('top')
227
228     def set_top(self, value, backup = False):
229         if backup:
230             curr = self._get_field('top')
231             self._set_field('top.old', curr)
232         self._set_field('top', value)
233         self.__update_top_ref(value)
234
235     def restore_old_boundaries(self):
236         bottom = self._get_field('bottom.old')
237         top = self._get_field('top.old')
238
239         if top and bottom:
240             self._set_field('bottom', bottom)
241             self._set_field('top', top)
242             self.__update_top_ref(top)
243             return True
244         else:
245             return False
246
247     def get_description(self):
248         return self._get_field('description', True)
249
250     def set_description(self, line):
251         self._set_field('description', line, True)
252
253     def get_authname(self):
254         return self._get_field('authname')
255
256     def set_authname(self, name):
257         self._set_field('authname', name or git.author().name)
258
259     def get_authemail(self):
260         return self._get_field('authemail')
261
262     def set_authemail(self, email):
263         self._set_field('authemail', email or git.author().email)
264
265     def get_authdate(self):
266         date = self._get_field('authdate')
267         if not date:
268             return date
269
270         if re.match('[0-9]+\s+[+-][0-9]+', date):
271             # Unix time (seconds) + time zone
272             secs_tz = date.split()
273             date = formatdate(int(secs_tz[0]))[:-5] + secs_tz[1]
274
275         return date
276
277     def set_authdate(self, date):
278         self._set_field('authdate', date or git.author().date)
279
280     def get_commname(self):
281         return self._get_field('commname')
282
283     def set_commname(self, name):
284         self._set_field('commname', name or git.committer().name)
285
286     def get_commemail(self):
287         return self._get_field('commemail')
288
289     def set_commemail(self, email):
290         self._set_field('commemail', email or git.committer().email)
291
292     def get_log(self):
293         return self._get_field('log')
294
295     def set_log(self, value, backup = False):
296         self._set_field('log', value)
297         self.__update_log_ref(value)
298
# The current StGIT metadata format version.  Series.init() records it
# for newly initialised branches and
# Series.update_to_current_format_version() upgrades older branches to it.
FORMAT_VERSION = 2
301
302 class PatchSet(StgitObject):
303     def __init__(self, name = None):
304         try:
305             if name:
306                 self.set_name (name)
307             else:
308                 self.set_name (git.get_head_file())
309             self.__base_dir = basedir.get()
310         except git.GitException, ex:
311             raise StackException, 'GIT tree not initialised: %s' % ex
312
313         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
314
315     def get_name(self):
316         return self.__name
317     def set_name(self, name):
318         self.__name = name
319
320     def _basedir(self):
321         return self.__base_dir
322
323     def get_head(self):
324         """Return the head of the branch
325         """
326         crt = self.get_current_patch()
327         if crt:
328             return crt.get_top()
329         else:
330             return self.get_base()
331
332     def get_protected(self):
333         return os.path.isfile(os.path.join(self._dir(), 'protected'))
334
335     def protect(self):
336         protect_file = os.path.join(self._dir(), 'protected')
337         if not os.path.isfile(protect_file):
338             create_empty_file(protect_file)
339
340     def unprotect(self):
341         protect_file = os.path.join(self._dir(), 'protected')
342         if os.path.isfile(protect_file):
343             os.remove(protect_file)
344
345     def __branch_descr(self):
346         return 'branch.%s.description' % self.get_name()
347
348     def get_description(self):
349         return config.get(self.__branch_descr()) or ''
350
351     def set_description(self, line):
352         if line:
353             config.set(self.__branch_descr(), line)
354         else:
355             config.unset(self.__branch_descr())
356
357     def head_top_equal(self):
358         """Return true if the head and the top are the same
359         """
360         crt = self.get_current_patch()
361         if not crt:
362             # we don't care, no patches applied
363             return True
364         return git.get_head() == crt.get_top()
365
366     def is_initialised(self):
367         """Checks if series is already initialised
368         """
369         return bool(config.get(self.format_version_key()))
370
371
def shortlog(patches):
    """Return the 'git shortlog' output for the commits of the given
    patches (top reachable, bottom excluded), in the order given
    """
    chunks = []
    for p in patches:
        history = Run('git', 'log', '--pretty=short',
                      p.get_top(), '^%s' % p.get_bottom())
        chunks.append(history.raw_output())
    return Run('git', 'shortlog').raw_input(''.join(chunks)).raw_output()
377
378 class Series(PatchSet):
379     """Class including the operations on series
380     """
381     def __init__(self, name = None):
382         """Takes a series name as the parameter.
383         """
384         PatchSet.__init__(self, name)
385
386         # Update the branch to the latest format version if it is
387         # initialized, but don't touch it if it isn't.
388         self.update_to_current_format_version()
389
390         self.__refs_base = 'refs/patches/%s' % self.get_name()
391
392         self.__applied_file = os.path.join(self._dir(), 'applied')
393         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
394         self.__hidden_file = os.path.join(self._dir(), 'hidden')
395
396         # where this series keeps its patches
397         self.__patch_dir = os.path.join(self._dir(), 'patches')
398
399         # trash directory
400         self.__trash_dir = os.path.join(self._dir(), 'trash')
401
402     def format_version_key(self):
403         return 'branch.%s.stgit.stackformatversion' % self.get_name()
404
    def update_to_current_format_version(self):
        """Update a potentially older StGIT directory structure to the
        latest version. Note: This function should depend as little as
        possible on external functions that may change during a format
        version bump, since it must remain able to process older formats."""

        branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
        def get_format_version():
            """Return the integer format version number, or None if the
            branch doesn't have any StGIT metadata at all, of any version."""
            fv = config.get(self.format_version_key())
            ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
            if fv:
                # Great, there's an explicitly recorded format version
                # number, which means that the branch is initialized and
                # of that exact version.
                return int(fv)
            elif ofv:
                # Old name for the version info, upgrade it
                config.set(self.format_version_key(), ofv)
                config.unset('branch.%s.stgitformatversion' % self.get_name())
                return int(ofv)
            elif os.path.isdir(os.path.join(branch_dir, 'patches')):
                # There's a .git/patches/<branch>/patches directory, which
                # means this is an initialized version 1 branch.
                return 1
            elif os.path.isdir(branch_dir):
                # There's a .git/patches/<branch> directory, which means
                # this is an initialized version 0 branch.
                return 0
            else:
                # The branch doesn't seem to be initialized at all.
                return None
        def set_format_version(v):
            # Report the upgrade and record the new version in the config.
            out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
            config.set(self.format_version_key(), '%d' % v)
        # The helpers below tolerate already-existing directories and
        # already-missing files/refs.
        def mkdir(d):
            if not os.path.isdir(d):
                os.makedirs(d)
        def rm(f):
            if os.path.exists(f):
                os.remove(f)
        def rm_ref(ref):
            if git.ref_exists(ref):
                git.delete_ref(ref)

        # Update 0 -> 1.
        # Create the trash and patches directories, move every patch
        # directory listed in 'unapplied' and 'applied' into the patches
        # directory and set its top ref from its stored 'top' field.
        if get_format_version() == 0:
            mkdir(os.path.join(branch_dir, 'trash'))
            patch_dir = os.path.join(branch_dir, 'patches')
            mkdir(patch_dir)
            refs_base = 'refs/patches/%s' % self.get_name()
            for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
                          + file(os.path.join(branch_dir, 'applied')).readlines()):
                patch = patch.strip()
                os.rename(os.path.join(branch_dir, patch),
                          os.path.join(patch_dir, patch))
                Patch(patch, patch_dir, refs_base).update_top_ref()
            set_format_version(1)

        # Update 1 -> 2.
        # Move the branch description from its file into the config and
        # remove the 'current' file and the refs/bases/<branch> ref.
        if get_format_version() == 1:
            desc_file = os.path.join(branch_dir, 'description')
            if os.path.isfile(desc_file):
                desc = read_string(desc_file)
                if desc:
                    config.set('branch.%s.description' % self.get_name(), desc)
                rm(desc_file)
            rm(os.path.join(branch_dir, 'current'))
            rm_ref('refs/bases/%s' % self.get_name())
            set_format_version(2)

        # Make sure we're at the latest version.
        # (None, i.e. an uninitialized branch, is accepted as well.)
        if not get_format_version() in [None, FORMAT_VERSION]:
            raise StackException('Branch %s is at format version %d, expected %d'
                                 % (self.get_name(), get_format_version(), FORMAT_VERSION))
481
482     def __patch_name_valid(self, name):
483         """Raise an exception if the patch name is not valid.
484         """
485         if not name or re.search('[^\w.-]', name):
486             raise StackException, 'Invalid patch name: "%s"' % name
487
488     def get_patch(self, name):
489         """Return a Patch object for the given name
490         """
491         return Patch(name, self.__patch_dir, self.__refs_base)
492
493     def get_current_patch(self):
494         """Return a Patch object representing the topmost patch, or
495         None if there is no such patch."""
496         crt = self.get_current()
497         if not crt:
498             return None
499         return self.get_patch(crt)
500
501     def get_current(self):
502         """Return the name of the topmost patch, or None if there is
503         no such patch."""
504         try:
505             applied = self.get_applied()
506         except StackException:
507             # No "applied" file: branch is not initialized.
508             return None
509         try:
510             return applied[-1]
511         except IndexError:
512             # No patches applied.
513             return None
514
515     def get_applied(self):
516         if not os.path.isfile(self.__applied_file):
517             raise StackException, 'Branch "%s" not initialised' % self.get_name()
518         return read_strings(self.__applied_file)
519
520     def set_applied(self, applied):
521         write_strings(self.__applied_file, applied)
522
523     def get_unapplied(self):
524         if not os.path.isfile(self.__unapplied_file):
525             raise StackException, 'Branch "%s" not initialised' % self.get_name()
526         return read_strings(self.__unapplied_file)
527
528     def set_unapplied(self, unapplied):
529         write_strings(self.__unapplied_file, unapplied)
530
531     def get_hidden(self):
532         if not os.path.isfile(self.__hidden_file):
533             return []
534         return read_strings(self.__hidden_file)
535
536     def get_base(self):
537         # Return the parent of the bottommost patch, if there is one.
538         if os.path.isfile(self.__applied_file):
539             bottommost = file(self.__applied_file).readline().strip()
540             if bottommost:
541                 return self.get_patch(bottommost).get_bottom()
542         # No bottommost patch, so just return HEAD
543         return git.get_head()
544
545     def get_parent_remote(self):
546         value = config.get('branch.%s.remote' % self.get_name())
547         if value:
548             return value
549         elif 'origin' in git.remotes_list():
550             out.note(('No parent remote declared for stack "%s",'
551                       ' defaulting to "origin".' % self.get_name()),
552                      ('Consider setting "branch.%s.remote" and'
553                       ' "branch.%s.merge" with "git config".'
554                       % (self.get_name(), self.get_name())))
555             return 'origin'
556         else:
557             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
558
559     def __set_parent_remote(self, remote):
560         value = config.set('branch.%s.remote' % self.get_name(), remote)
561
562     def get_parent_branch(self):
563         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
564         if value:
565             return value
566         elif git.rev_parse('heads/origin'):
567             out.note(('No parent branch declared for stack "%s",'
568                       ' defaulting to "heads/origin".' % self.get_name()),
569                      ('Consider setting "branch.%s.stgit.parentbranch"'
570                       ' with "git config".' % self.get_name()))
571             return 'heads/origin'
572         else:
573             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
574
575     def __set_parent_branch(self, name):
576         if config.get('branch.%s.remote' % self.get_name()):
577             # Never set merge if remote is not set to avoid
578             # possibly-erroneous lookups into 'origin'
579             config.set('branch.%s.merge' % self.get_name(), name)
580         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
581
582     def set_parent(self, remote, localbranch):
583         if localbranch:
584             if remote:
585                 self.__set_parent_remote(remote)
586             self.__set_parent_branch(localbranch)
587         # We'll enforce this later
588 #         else:
589 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
590
591     def __patch_is_current(self, patch):
592         return patch.get_name() == self.get_current()
593
594     def patch_applied(self, name):
595         """Return true if the patch exists in the applied list
596         """
597         return name in self.get_applied()
598
599     def patch_unapplied(self, name):
600         """Return true if the patch exists in the unapplied list
601         """
602         return name in self.get_unapplied()
603
604     def patch_hidden(self, name):
605         """Return true if the patch is hidden.
606         """
607         return name in self.get_hidden()
608
609     def patch_exists(self, name):
610         """Return true if there is a patch with the given name, false
611         otherwise."""
612         return self.patch_applied(name) or self.patch_unapplied(name) \
613                or self.patch_hidden(name)
614
615     def init(self, create_at=False, parent_remote=None, parent_branch=None):
616         """Initialises the stgit series
617         """
618         if self.is_initialised():
619             raise StackException, '%s already initialized' % self.get_name()
620         for d in [self._dir()]:
621             if os.path.exists(d):
622                 raise StackException, '%s already exists' % d
623
624         if (create_at!=False):
625             git.create_branch(self.get_name(), create_at)
626
627         os.makedirs(self.__patch_dir)
628
629         self.set_parent(parent_remote, parent_branch)
630
631         self.create_empty_field('applied')
632         self.create_empty_field('unapplied')
633
634         config.set(self.format_version_key(), str(FORMAT_VERSION))
635
636     def rename(self, to_name):
637         """Renames a series
638         """
639         to_stack = Series(to_name)
640
641         if to_stack.is_initialised():
642             raise StackException, '"%s" already exists' % to_stack.get_name()
643
644         patches = self.get_applied() + self.get_unapplied()
645
646         git.rename_branch(self.get_name(), to_name)
647
648         for patch in patches:
649             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
650                            'refs/patches/%s/%s' % (to_name, patch))
651             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
652                            'refs/patches/%s/%s.log' % (to_name, patch))
653         if os.path.isdir(self._dir()):
654             rename(os.path.join(self._basedir(), 'patches'),
655                    self.get_name(), to_stack.get_name())
656
657         # Rename the config section
658         for k in ['branch.%s', 'branch.%s.stgit']:
659             config.rename_section(k % self.get_name(), k % to_name)
660
661         self.__init__(to_name)
662
663     def clone(self, target_series):
664         """Clones a series
665         """
666         try:
667             # allow cloning of branches not under StGIT control
668             base = self.get_base()
669         except:
670             base = git.get_head()
671         Series(target_series).init(create_at = base)
672         new_series = Series(target_series)
673
674         # generate an artificial description file
675         new_series.set_description('clone of "%s"' % self.get_name())
676
677         # clone self's entire series as unapplied patches
678         try:
679             # allow cloning of branches not under StGIT control
680             applied = self.get_applied()
681             unapplied = self.get_unapplied()
682             patches = applied + unapplied
683             patches.reverse()
684         except:
685             patches = applied = unapplied = []
686         for p in patches:
687             patch = self.get_patch(p)
688             newpatch = new_series.new_patch(p, message = patch.get_description(),
689                                             can_edit = False, unapplied = True,
690                                             bottom = patch.get_bottom(),
691                                             top = patch.get_top(),
692                                             author_name = patch.get_authname(),
693                                             author_email = patch.get_authemail(),
694                                             author_date = patch.get_authdate())
695             if patch.get_log():
696                 out.info('Setting log to %s' %  patch.get_log())
697                 newpatch.set_log(patch.get_log())
698             else:
699                 out.info('No log for %s' % p)
700
701         # fast forward the cloned series to self's top
702         new_series.forward_patches(applied)
703
704         # Clone parent informations
705         value = config.get('branch.%s.remote' % self.get_name())
706         if value:
707             config.set('branch.%s.remote' % target_series, value)
708
709         value = config.get('branch.%s.merge' % self.get_name())
710         if value:
711             config.set('branch.%s.merge' % target_series, value)
712
713         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
714         if value:
715             config.set('branch.%s.stgit.parentbranch' % target_series, value)
716
717     def delete(self, force = False):
718         """Deletes an stgit series
719         """
720         if self.is_initialised():
721             patches = self.get_unapplied() + self.get_applied()
722             if not force and patches:
723                 raise StackException, \
724                       'Cannot delete: the series still contains patches'
725             for p in patches:
726                 self.get_patch(p).delete()
727
728             # remove the trash directory if any
729             if os.path.exists(self.__trash_dir):
730                 for fname in os.listdir(self.__trash_dir):
731                     os.remove(os.path.join(self.__trash_dir, fname))
732                 os.rmdir(self.__trash_dir)
733
734             # FIXME: find a way to get rid of those manual removals
735             # (move functionality to StgitObject ?)
736             if os.path.exists(self.__applied_file):
737                 os.remove(self.__applied_file)
738             if os.path.exists(self.__unapplied_file):
739                 os.remove(self.__unapplied_file)
740             if os.path.exists(self.__hidden_file):
741                 os.remove(self.__hidden_file)
742             if os.path.exists(self._dir()+'/orig-base'):
743                 os.remove(self._dir()+'/orig-base')
744
745             if not os.listdir(self.__patch_dir):
746                 os.rmdir(self.__patch_dir)
747             else:
748                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
749
750             try:
751                 os.removedirs(self._dir())
752             except OSError:
753                 raise StackException('Series directory %s is not empty'
754                                      % self._dir())
755
756             try:
757                 git.delete_branch(self.get_name())
758             except GitException:
759                 out.warn('Could not delete branch "%s"' % self.get_name())
760
761         config.remove_section('branch.%s' % self.get_name())
762         config.remove_section('branch.%s.stgit' % self.get_name())
763
764     def refresh_patch(self, files = None, message = None, edit = False,
765                       show_patch = False,
766                       cache_update = True,
767                       author_name = None, author_email = None,
768                       author_date = None,
769                       committer_name = None, committer_email = None,
770                       backup = True, sign_str = None, log = 'refresh',
771                       notes = None, bottom = None):
772         """Generates a new commit for the topmost patch
773         """
774         patch = self.get_current_patch()
775         if not patch:
776             raise StackException, 'No patches applied'
777
778         descr = patch.get_description()
779         if not (message or descr):
780             edit = True
781             descr = ''
782         elif message:
783             descr = message
784
785         # TODO: move this out of the stgit.stack module, it is really
786         # for higher level commands to handle the user interaction
787         if not message and edit:
788             descr = edit_file(self, descr.rstrip(), \
789                               'Please edit the description for patch "%s" ' \
790                               'above.' % patch.get_name(), show_patch)
791
792         if not author_name:
793             author_name = patch.get_authname()
794         if not author_email:
795             author_email = patch.get_authemail()
796         if not author_date:
797             author_date = patch.get_authdate()
798         if not committer_name:
799             committer_name = patch.get_commname()
800         if not committer_email:
801             committer_email = patch.get_commemail()
802
803         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
804
805         if not bottom:
806             bottom = patch.get_bottom()
807
808         commit_id = git.commit(files = files,
809                                message = descr, parents = [bottom],
810                                cache_update = cache_update,
811                                allowempty = True,
812                                author_name = author_name,
813                                author_email = author_email,
814                                author_date = author_date,
815                                committer_name = committer_name,
816                                committer_email = committer_email)
817
818         patch.set_bottom(bottom, backup = backup)
819         patch.set_top(commit_id, backup = backup)
820         patch.set_description(descr)
821         patch.set_authname(author_name)
822         patch.set_authemail(author_email)
823         patch.set_authdate(author_date)
824         patch.set_commname(committer_name)
825         patch.set_commemail(committer_email)
826
827         if log:
828             self.log_patch(patch, log, notes)
829
830         return commit_id
831
832     def undo_refresh(self):
833         """Undo the patch boundaries changes caused by 'refresh'
834         """
835         name = self.get_current()
836         assert(name)
837
838         patch = self.get_patch(name)
839         old_bottom = patch.get_old_bottom()
840         old_top = patch.get_old_top()
841
842         # the bottom of the patch is not changed by refresh. If the
843         # old_bottom is different, there wasn't any previous 'refresh'
844         # command (probably only a 'push')
845         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
846             raise StackException, 'No undo information available'
847
848         git.reset(tree_id = old_top, check_out = False)
849         if patch.restore_old_boundaries():
850             self.log_patch(patch, 'undo')
851
852     def new_patch(self, name, message = None, can_edit = True,
853                   unapplied = False, show_patch = False,
854                   top = None, bottom = None, commit = True,
855                   author_name = None, author_email = None, author_date = None,
856                   committer_name = None, committer_email = None,
857                   before_existing = False, sign_str = None):
858         """Creates a new patch, either pointing to an existing commit object,
859         or by creating a new commit object.
860         """
861
862         assert commit or (top and bottom)
863         assert not before_existing or (top and bottom)
864         assert not (commit and before_existing)
865         assert (top and bottom) or (not top and not bottom)
866         assert commit or (not top or (bottom == git.get_commit(top).get_parent()))
867
868         if name != None:
869             self.__patch_name_valid(name)
870             if self.patch_exists(name):
871                 raise StackException, 'Patch "%s" already exists' % name
872
873         # TODO: move this out of the stgit.stack module, it is really
874         # for higher level commands to handle the user interaction
875         def sign(msg):
876             return add_sign_line(msg, sign_str,
877                                  committer_name or git.committer().name,
878                                  committer_email or git.committer().email)
879         if not message and can_edit:
880             descr = edit_file(
881                 self, sign(''),
882                 'Please enter the description for the patch above.',
883                 show_patch)
884         else:
885             descr = sign(message)
886
887         head = git.get_head()
888
889         if name == None:
890             name = make_patch_name(descr, self.patch_exists)
891
892         patch = self.get_patch(name)
893         patch.create()
894
895         patch.set_description(descr)
896         patch.set_authname(author_name)
897         patch.set_authemail(author_email)
898         patch.set_authdate(author_date)
899         patch.set_commname(committer_name)
900         patch.set_commemail(committer_email)
901
902         if before_existing:
903             insert_string(self.__applied_file, patch.get_name())
904         elif unapplied:
905             patches = [patch.get_name()] + self.get_unapplied()
906             write_strings(self.__unapplied_file, patches)
907             set_head = False
908         else:
909             append_string(self.__applied_file, patch.get_name())
910             set_head = True
911
912         if commit:
913             if top:
914                 top_commit = git.get_commit(top)
915             else:
916                 bottom = head
917                 top_commit = git.get_commit(head)
918
919             # create a commit for the patch (may be empty if top == bottom);
920             # only commit on top of the current branch
921             assert(unapplied or bottom == head)
922             commit_id = git.commit(message = descr, parents = [bottom],
923                                    cache_update = False,
924                                    tree_id = top_commit.get_tree(),
925                                    allowempty = True, set_head = set_head,
926                                    author_name = author_name,
927                                    author_email = author_email,
928                                    author_date = author_date,
929                                    committer_name = committer_name,
930                                    committer_email = committer_email)
931             # set the patch top to the new commit
932             patch.set_bottom(bottom)
933             patch.set_top(commit_id)
934         else:
935             assert top != bottom
936             patch.set_bottom(bottom)
937             patch.set_top(top)
938
939         self.log_patch(patch, 'new')
940
941         return patch
942
943     def delete_patch(self, name, keep_log = False):
944         """Deletes a patch
945         """
946         self.__patch_name_valid(name)
947         patch = self.get_patch(name)
948
949         if self.__patch_is_current(patch):
950             self.pop_patch(name)
951         elif self.patch_applied(name):
952             raise StackException, 'Cannot remove an applied patch, "%s", ' \
953                   'which is not current' % name
954         elif not name in self.get_unapplied():
955             raise StackException, 'Unknown patch "%s"' % name
956
957         # save the commit id to a trash file
958         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
959
960         patch.delete(keep_log = keep_log)
961
962         unapplied = self.get_unapplied()
963         unapplied.remove(name)
964         write_strings(self.__unapplied_file, unapplied)
965
966     def forward_patches(self, names):
967         """Try to fast-forward an array of patches.
968
969         On return, patches in names[0:returned_value] have been pushed on the
970         stack. Apply the rest with push_patch
971         """
972         unapplied = self.get_unapplied()
973
974         forwarded = 0
975         top = git.get_head()
976
977         for name in names:
978             assert(name in unapplied)
979
980             patch = self.get_patch(name)
981
982             head = top
983             bottom = patch.get_bottom()
984             top = patch.get_top()
985
986             # top != bottom always since we have a commit for each patch
987             if head == bottom:
988                 # reset the backup information. No logging since the
989                 # patch hasn't changed
990                 patch.set_bottom(head, backup = True)
991                 patch.set_top(top, backup = True)
992
993             else:
994                 head_tree = git.get_commit(head).get_tree()
995                 bottom_tree = git.get_commit(bottom).get_tree()
996                 if head_tree == bottom_tree:
997                     # We must just reparent this patch and create a new commit
998                     # for it
999                     descr = patch.get_description()
1000                     author_name = patch.get_authname()
1001                     author_email = patch.get_authemail()
1002                     author_date = patch.get_authdate()
1003                     committer_name = patch.get_commname()
1004                     committer_email = patch.get_commemail()
1005
1006                     top_tree = git.get_commit(top).get_tree()
1007
1008                     top = git.commit(message = descr, parents = [head],
1009                                      cache_update = False,
1010                                      tree_id = top_tree,
1011                                      allowempty = True,
1012                                      author_name = author_name,
1013                                      author_email = author_email,
1014                                      author_date = author_date,
1015                                      committer_name = committer_name,
1016                                      committer_email = committer_email)
1017
1018                     patch.set_bottom(head, backup = True)
1019                     patch.set_top(top, backup = True)
1020
1021                     self.log_patch(patch, 'push(f)')
1022                 else:
1023                     top = head
1024                     # stop the fast-forwarding, must do a real merge
1025                     break
1026
1027             forwarded+=1
1028             unapplied.remove(name)
1029
1030         if forwarded == 0:
1031             return 0
1032
1033         git.switch(top)
1034
1035         append_strings(self.__applied_file, names[0:forwarded])
1036         write_strings(self.__unapplied_file, unapplied)
1037
1038         return forwarded
1039
1040     def merged_patches(self, names):
1041         """Test which patches were merged upstream by reverse-applying
1042         them in reverse order. The function returns the list of
1043         patches detected to have been applied. The state of the tree
1044         is restored to the original one
1045         """
1046         patches = [self.get_patch(name) for name in names]
1047         patches.reverse()
1048
1049         merged = []
1050         for p in patches:
1051             if git.apply_diff(p.get_top(), p.get_bottom()):
1052                 merged.append(p.get_name())
1053         merged.reverse()
1054
1055         git.reset()
1056
1057         return merged
1058
1059     def push_empty_patch(self, name):
1060         """Pushes an empty patch on the stack
1061         """
1062         unapplied = self.get_unapplied()
1063         assert(name in unapplied)
1064
1065         # patch = self.get_patch(name)
1066         head = git.get_head()
1067
1068         append_string(self.__applied_file, name)
1069
1070         unapplied.remove(name)
1071         write_strings(self.__unapplied_file, unapplied)
1072
1073         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1074
1075     def push_patch(self, name):
1076         """Pushes a patch on the stack
1077         """
1078         unapplied = self.get_unapplied()
1079         assert(name in unapplied)
1080
1081         patch = self.get_patch(name)
1082
1083         head = git.get_head()
1084         bottom = patch.get_bottom()
1085         top = patch.get_top()
1086         # top != bottom always since we have a commit for each patch
1087
1088         if head == bottom:
1089             # A fast-forward push. Just reset the backup
1090             # information. No need for logging
1091             patch.set_bottom(bottom, backup = True)
1092             patch.set_top(top, backup = True)
1093
1094             git.switch(top)
1095             append_string(self.__applied_file, name)
1096
1097             unapplied.remove(name)
1098             write_strings(self.__unapplied_file, unapplied)
1099             return False
1100
1101         # Need to create a new commit an merge in the old patch
1102         ex = None
1103         modified = False
1104
1105         # Try the fast applying first. If this fails, fall back to the
1106         # three-way merge
1107         if not git.apply_diff(bottom, top):
1108             # if git.apply_diff() fails, the patch requires a diff3
1109             # merge and can be reported as modified
1110             modified = True
1111
1112             # merge can fail but the patch needs to be pushed
1113             try:
1114                 git.merge(bottom, head, top, recursive = True)
1115             except git.GitException, ex:
1116                 out.error('The merge failed during "push".',
1117                           'Use "refresh" after fixing the conflicts or'
1118                           ' revert the operation with "push --undo".')
1119
1120         append_string(self.__applied_file, name)
1121
1122         unapplied.remove(name)
1123         write_strings(self.__unapplied_file, unapplied)
1124
1125         if not ex:
1126             # if the merge was OK and no conflicts, just refresh the patch
1127             # The GIT cache was already updated by the merge operation
1128             if modified:
1129                 log = 'push(m)'
1130             else:
1131                 log = 'push'
1132             self.refresh_patch(bottom = head, cache_update = False, log = log)
1133         else:
1134             # we store the correctly merged files only for
1135             # tracking the conflict history. Note that the
1136             # git.merge() operations should always leave the index
1137             # in a valid state (i.e. only stage 0 files)
1138             self.refresh_patch(bottom = head, cache_update = False,
1139                                log = 'push(c)')
1140             raise StackException, str(ex)
1141
1142         return modified
1143
1144     def undo_push(self):
1145         name = self.get_current()
1146         assert(name)
1147
1148         patch = self.get_patch(name)
1149         old_bottom = patch.get_old_bottom()
1150         old_top = patch.get_old_top()
1151
1152         # the top of the patch is changed by a push operation only
1153         # together with the bottom (otherwise the top was probably
1154         # modified by 'refresh'). If they are both unchanged, there
1155         # was a fast forward
1156         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1157             raise StackException, 'No undo information available'
1158
1159         git.reset()
1160         self.pop_patch(name)
1161         ret = patch.restore_old_boundaries()
1162         if ret:
1163             self.log_patch(patch, 'undo')
1164
1165         return ret
1166
1167     def pop_patch(self, name, keep = False):
1168         """Pops the top patch from the stack
1169         """
1170         applied = self.get_applied()
1171         applied.reverse()
1172         assert(name in applied)
1173
1174         patch = self.get_patch(name)
1175
1176         if git.get_head_file() == self.get_name():
1177             if keep and not git.apply_diff(git.get_head(), patch.get_bottom(),
1178                                            check_index = False):
1179                 raise StackException(
1180                     'Failed to pop patches while preserving the local changes')
1181             git.switch(patch.get_bottom(), keep)
1182         else:
1183             git.set_branch(self.get_name(), patch.get_bottom())
1184
1185         # save the new applied list
1186         idx = applied.index(name) + 1
1187
1188         popped = applied[:idx]
1189         popped.reverse()
1190         unapplied = popped + self.get_unapplied()
1191         write_strings(self.__unapplied_file, unapplied)
1192
1193         del applied[:idx]
1194         applied.reverse()
1195         write_strings(self.__applied_file, applied)
1196
1197     def empty_patch(self, name):
1198         """Returns True if the patch is empty
1199         """
1200         self.__patch_name_valid(name)
1201         patch = self.get_patch(name)
1202         bottom = patch.get_bottom()
1203         top = patch.get_top()
1204
1205         if bottom == top:
1206             return True
1207         elif git.get_commit(top).get_tree() \
1208                  == git.get_commit(bottom).get_tree():
1209             return True
1210
1211         return False
1212
1213     def rename_patch(self, oldname, newname):
1214         self.__patch_name_valid(newname)
1215
1216         applied = self.get_applied()
1217         unapplied = self.get_unapplied()
1218
1219         if oldname == newname:
1220             raise StackException, '"To" name and "from" name are the same'
1221
1222         if newname in applied or newname in unapplied:
1223             raise StackException, 'Patch "%s" already exists' % newname
1224
1225         if oldname in unapplied:
1226             self.get_patch(oldname).rename(newname)
1227             unapplied[unapplied.index(oldname)] = newname
1228             write_strings(self.__unapplied_file, unapplied)
1229         elif oldname in applied:
1230             self.get_patch(oldname).rename(newname)
1231
1232             applied[applied.index(oldname)] = newname
1233             write_strings(self.__applied_file, applied)
1234         else:
1235             raise StackException, 'Unknown patch "%s"' % oldname
1236
1237     def log_patch(self, patch, message, notes = None):
1238         """Generate a log commit for a patch
1239         """
1240         top = git.get_commit(patch.get_top())
1241         old_log = patch.get_log()
1242
1243         if message is None:
1244             # replace the current log entry
1245             if not old_log:
1246                 raise StackException, \
1247                       'No log entry to annotate for patch "%s"' \
1248                       % patch.get_name()
1249             replace = True
1250             log_commit = git.get_commit(old_log)
1251             msg = log_commit.get_log().split('\n')[0]
1252             log_parent = log_commit.get_parent()
1253             if log_parent:
1254                 parents = [log_parent]
1255             else:
1256                 parents = []
1257         else:
1258             # generate a new log entry
1259             replace = False
1260             msg = '%s\t%s' % (message, top.get_id_hash())
1261             if old_log:
1262                 parents = [old_log]
1263             else:
1264                 parents = []
1265
1266         if notes:
1267             msg += '\n\n' + notes
1268
1269         log = git.commit(message = msg, parents = parents,
1270                          cache_update = False, tree_id = top.get_tree(),
1271                          allowempty = True)
1272         patch.set_log(log)
1273
1274     def hide_patch(self, name):
1275         """Add the patch to the hidden list.
1276         """
1277         unapplied = self.get_unapplied()
1278         if name not in unapplied:
1279             # keep the checking order for backward compatibility with
1280             # the old hidden patches functionality
1281             if self.patch_applied(name):
1282                 raise StackException, 'Cannot hide applied patch "%s"' % name
1283             elif self.patch_hidden(name):
1284                 raise StackException, 'Patch "%s" already hidden' % name
1285             else:
1286                 raise StackException, 'Unknown patch "%s"' % name
1287
1288         if not self.patch_hidden(name):
1289             # check needed for backward compatibility with the old
1290             # hidden patches functionality
1291             append_string(self.__hidden_file, name)
1292
1293         unapplied.remove(name)
1294         write_strings(self.__unapplied_file, unapplied)
1295
1296     def unhide_patch(self, name):
1297         """Remove the patch from the hidden list.
1298         """
1299         hidden = self.get_hidden()
1300         if not name in hidden:
1301             if self.patch_applied(name) or self.patch_unapplied(name):
1302                 raise StackException, 'Patch "%s" not hidden' % name
1303             else:
1304                 raise StackException, 'Unknown patch "%s"' % name
1305
1306         hidden.remove(name)
1307         write_strings(self.__hidden_file, hidden)
1308
1309         if not self.patch_applied(name) and not self.patch_unapplied(name):
1310             # check needed for backward compatibility with the old
1311             # hidden patches functionality
1312             append_string(self.__unapplied_file, name)