chiark / gitweb /
bd08b35ffede59286893b4b67271f579ab816431
[stgit] / stgit / stack.py
1 """Basic quilt-like functionality
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, re
22 from email.Utils import formatdate
23
24 from stgit.utils import *
25 from stgit.out import *
26 from stgit.run import *
27 from stgit import git, basedir, templates
28 from stgit.config import config
29 from shutil import copyfile
30
31
32 # stack exception class
class StackException(Exception):
    """Raised when an StGIT stack or patch operation cannot be performed."""
35
class FilterUntil:
    """Stateful predicate for filtering a sequence of lines.

    Each call returns True for a line that should be kept: one that does
    not start with PREFIX and that comes before the first line for which
    UNTIL_TEST returns true.  That line and every line after it are
    rejected.
    """
    def __init__(self):
        self.should_print = True

    def __call__(self, x, until_test, prefix):
        if until_test(x):
            # latch: from the stop line onwards nothing is kept
            self.should_print = False
        return self.should_print and not x.startswith(prefix)
45
46 #
47 # Functions
48 #
# Lines of the commit message file that start with this prefix are
# removed once the user has finished editing (see __clean_comments).
__comment_prefix = 'STG:'
# A line consisting solely of this marker (followed by a newline) ends the
# editable message: the marker line and everything after it (edit_file
# writes the diff there) are removed as well.
__patch_prefix = 'STG_PATCH:'
51
def __clean_comments(f):
    """Strip the StGIT annotations from an open commit message file.

    Lines before the first line equal to the patch marker are kept unless
    they start with the comment prefix.  The marker line and everything
    after it are dropped, as are trailing empty lines.  The file (opened
    for reading and writing) is rewritten in place and truncated.
    """
    f.seek(0)
    marker = __patch_prefix + '\n'

    kept = []
    for text in f.readlines():
        if text == marker:
            # everything from the marker on is discarded
            break
        if not text.startswith(__comment_prefix):
            kept.append(text)

    # drop empty lines at the end
    while kept and kept[-1] == '\n':
        kept.pop()

    f.seek(0)
    f.truncate()
    f.writelines(kept)
70
71 # TODO: move this out of the stgit.stack module, it is really for
72 # higher level commands to handle the user interaction
def edit_file(series, line, comment, show_patch = True):
    """Let the user edit a patch description in an external editor.

    The temporary file '.stgitmsg.txt' (in the current directory) is
    pre-filled with LINE.  When LINE is empty, the 'patchdescr.tmpl'
    template is used instead, or a blank line if there is no template.
    Instruction lines carrying the comment prefix follow, COMMENT among
    them.

    If SHOW_PATCH is true, a patch-marker line and a diff are appended.
    The diff is taken against the bottom of the topmost applied patch,
    or against the series base when no patch is applied.

    After call_editor() returns, the comment lines, the marker and
    everything after it, and any trailing empty lines are removed.  The
    remaining text is returned and the temporary file is deleted.
    """
    fname = '.stgitmsg.txt'
    tmpl = templates.get_template('patchdescr.tmpl')

    f = open(fname, 'w+')
    try:
        if line:
            f.write('%s\n' % line)
        elif tmpl:
            f.write(tmpl)
            # the comment lines below must start on a line of their own,
            # otherwise they would not be recognised and stripped
            if not tmpl.endswith('\n'):
                f.write('\n')
        else:
            f.write('\n')
        f.write('%s %s\n' % (__comment_prefix, comment))
        f.write('%s Lines prefixed with "%s" will be automatically removed.\n'
                % (__comment_prefix, __comment_prefix))
        f.write('%s Trailing empty lines will be automatically removed.\n'
                % __comment_prefix)

        if show_patch:
            f.write(__patch_prefix + '\n')
            crt = series.get_current_patch()
            if crt:
                rev1 = crt.get_bottom()
            else:
                # no patch applied: diff against the series base rather
                # than looking up a patch named None
                rev1 = series.get_base()
            diff_str = git.diff(rev1 = rev1)
            f.write(diff_str)
            # keep the modeline below on a line of its own
            if diff_str and not diff_str.endswith('\n'):
                f.write('\n')

        # Vim modeline must be near the end.
        f.write('%s vi: set textwidth=75 filetype=diff nobackup:\n'
                % __comment_prefix)
    finally:
        f.close()

    call_editor(fname)

    f = open(fname, 'r+')
    try:
        __clean_comments(f)
        f.seek(0)
        result = f.read()
    finally:
        f.close()
    os.remove(fname)

    return result
113
114 #
115 # Classes
116 #
117
class StgitObject:
    """An object with stgit-like properties stored as files in a directory

    Each property ("field") lives in its own file, named after the field,
    inside the directory given to _set_dir().  A missing field file and
    one that reads back as the empty string are both reported as None.
    """
    def _set_dir(self, dir):
        self.__dir = dir
    def _dir(self):
        return self.__dir

    def create_empty_field(self, name):
        create_empty_file(os.path.join(self.__dir, name))

    def _get_field(self, name, multiline = False):
        """Return the content of field NAME, or None if it is missing or
        empty."""
        fname = os.path.join(self.__dir, name)
        if not os.path.isfile(fname):
            return None
        value = read_string(fname, multiline)
        if value == '':
            return None
        return value

    def _set_field(self, name, value, multiline = False):
        """Store VALUE in field NAME.  A false VALUE (None or '') removes
        the field file instead, if it exists."""
        fname = os.path.join(self.__dir, name)
        if value:
            write_string(fname, value, multiline)
        elif os.path.isfile(fname):
            os.remove(fname)
146
147     
class Patch(StgitObject):
    """Basic patch implementation

    The patch metadata is kept as one file per field in the directory
    <series_dir>/<name>.  The top commit is also recorded in the ref
    <refs_base>/<name>, and the log in <refs_base>/<name>.log.
    """
    def __init_refs(self):
        # derive the ref names from the current patch name
        self.__top_ref = self.__refs_base + '/' + self.__name
        self.__log_ref = self.__top_ref + '.log'

    def __init__(self, name, series_dir, refs_base):
        self.__series_dir = series_dir
        self.__name = name
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__refs_base = refs_base
        self.__init_refs()

    def create(self):
        """Create the patch directory with empty 'bottom' and 'top' fields."""
        os.mkdir(self._dir())
        self.create_empty_field('bottom')
        self.create_empty_field('top')

    def delete(self):
        """Remove the patch directory, its top ref and (if present) its
        log ref."""
        for f in os.listdir(self._dir()):
            os.remove(os.path.join(self._dir(), f))
        os.rmdir(self._dir())
        git.delete_ref(self.__top_ref)
        if git.ref_exists(self.__log_ref):
            git.delete_ref(self.__log_ref)

    def get_name(self):
        return self.__name

    def rename(self, newname):
        """Rename the patch: move its refs (the log ref only if present)
        and its metadata directory."""
        olddir = self._dir()
        old_top_ref = self.__top_ref
        old_log_ref = self.__log_ref
        self.__name = newname
        self._set_dir(os.path.join(self.__series_dir, self.__name))
        self.__init_refs()

        git.rename_ref(old_top_ref, self.__top_ref)
        if git.ref_exists(old_log_ref):
            git.rename_ref(old_log_ref, self.__log_ref)
        os.rename(olddir, self._dir())

    def __update_top_ref(self, ref):
        git.set_ref(self.__top_ref, ref)

    def __update_log_ref(self, ref):
        git.set_ref(self.__log_ref, ref)

    def update_top_ref(self):
        """Point the top ref at the stored 'top' field, if it is set."""
        top = self.get_top()
        if top:
            self.__update_top_ref(top)

    def get_old_bottom(self):
        return self._get_field('bottom.old')

    def get_bottom(self):
        return self._get_field('bottom')

    def set_bottom(self, value, backup = False):
        """Set the bottom field.  With BACKUP, the previous value is first
        saved in 'bottom.old'."""
        if backup:
            curr = self._get_field('bottom')
            self._set_field('bottom.old', curr)
        self._set_field('bottom', value)

    def get_old_top(self):
        return self._get_field('top.old')

    def get_top(self):
        return self._get_field('top')

    def set_top(self, value, backup = False):
        """Set the top field and the top ref.  With BACKUP, the previous
        value is first saved in 'top.old'."""
        if backup:
            curr = self._get_field('top')
            self._set_field('top.old', curr)
        self._set_field('top', value)
        self.__update_top_ref(value)

    def restore_old_boundaries(self):
        """Restore bottom/top from their '.old' backups.

        Returns True if both backups existed and were restored, and False
        (changing nothing) otherwise.
        """
        bottom = self._get_field('bottom.old')
        top = self._get_field('top.old')

        if top and bottom:
            self._set_field('bottom', bottom)
            self._set_field('top', top)
            self.__update_top_ref(top)
            return True
        else:
            return False

    def get_description(self):
        return self._get_field('description', True)

    def set_description(self, line):
        self._set_field('description', line, True)

    def get_authname(self):
        return self._get_field('authname')

    def set_authname(self, name):
        self._set_field('authname', name or git.author().name)

    def get_authemail(self):
        return self._get_field('authemail')

    def set_authemail(self, email):
        self._set_field('authemail', email or git.author().email)

    def get_authdate(self):
        """Return the author date, or the stored false value if unset.

        A date stored in git's raw "<seconds> <+|-HHMM>" form is converted
        to an RFC 2822 string showing the wall-clock time in that same time
        zone.  Any other stored value is returned unchanged.
        """
        date = self._get_field('authdate')
        if not date:
            return date

        m = re.match(r'([0-9]+)\s+([+-])([0-9]{2})([0-9]{2})', date)
        if m:
            # Unix time (seconds) + time zone
            secs, sign, hours, mins = m.groups()
            offset = int(hours) * 3600 + int(mins) * 60
            if sign == '-':
                offset = -offset
            # formatdate() renders the UTC clock time followed by "-0000".
            # Shift the timestamp by the zone offset so that the clock time
            # matches the zone appended below.  Otherwise the date would
            # move by the offset when parsed back.
            date = formatdate(int(secs) + offset)[:-5] + sign + hours + mins

        return date

    def set_authdate(self, date):
        self._set_field('authdate', date or git.author().date)

    def get_commname(self):
        return self._get_field('commname')

    def set_commname(self, name):
        self._set_field('commname', name or git.committer().name)

    def get_commemail(self):
        return self._get_field('commemail')

    def set_commemail(self, email):
        self._set_field('commemail', email or git.committer().email)

    def get_log(self):
        return self._get_field('log')

    def set_log(self, value, backup = False):
        """Store the log field and update the log ref.  BACKUP is accepted
        for interface symmetry with set_top/set_bottom but is unused."""
        self._set_field('log', value)
        self.__update_log_ref(value)
290
# The current StGIT metadata format version.  Series.init() records it in
# git config, and Series.update_to_current_format_version() upgrades older
# branches step by step to it, raising StackException if a branch ends up
# at any other version.
FORMAT_VERSION = 2
293
294 class PatchSet(StgitObject):
295     def __init__(self, name = None):
296         try:
297             if name:
298                 self.set_name (name)
299             else:
300                 self.set_name (git.get_head_file())
301             self.__base_dir = basedir.get()
302         except git.GitException, ex:
303             raise StackException, 'GIT tree not initialised: %s' % ex
304
305         self._set_dir(os.path.join(self.__base_dir, 'patches', self.get_name()))
306
307     def get_name(self):
308         return self.__name
309     def set_name(self, name):
310         self.__name = name
311
312     def _basedir(self):
313         return self.__base_dir
314
315     def get_head(self):
316         """Return the head of the branch
317         """
318         crt = self.get_current_patch()
319         if crt:
320             return crt.get_top()
321         else:
322             return self.get_base()
323
324     def get_protected(self):
325         return os.path.isfile(os.path.join(self._dir(), 'protected'))
326
327     def protect(self):
328         protect_file = os.path.join(self._dir(), 'protected')
329         if not os.path.isfile(protect_file):
330             create_empty_file(protect_file)
331
332     def unprotect(self):
333         protect_file = os.path.join(self._dir(), 'protected')
334         if os.path.isfile(protect_file):
335             os.remove(protect_file)
336
337     def __branch_descr(self):
338         return 'branch.%s.description' % self.get_name()
339
340     def get_description(self):
341         return config.get(self.__branch_descr()) or ''
342
343     def set_description(self, line):
344         if line:
345             config.set(self.__branch_descr(), line)
346         else:
347             config.unset(self.__branch_descr())
348
349     def head_top_equal(self):
350         """Return true if the head and the top are the same
351         """
352         crt = self.get_current_patch()
353         if not crt:
354             # we don't care, no patches applied
355             return True
356         return git.get_head() == crt.get_top()
357
358     def is_initialised(self):
359         """Checks if series is already initialised
360         """
361         return bool(config.get(self.format_version_key()))
362
363
def shortlog(patches):
    """Return the shortlog summary of the commits in PATCHES.

    For each patch, 'git log --pretty=short <top> ^<bottom>' is run.  The
    outputs are concatenated and fed to 'git shortlog' on standard input.
    """
    # Use the 'git <command>' form.  The dashed 'git-<command>' programs
    # are not installed on the default PATH since git 1.6.
    log = ''.join(Run('git', 'log', '--pretty=short',
                      p.get_top(), '^%s' % p.get_bottom()).raw_output()
                  for p in patches)
    return Run('git', 'shortlog').raw_input(log).raw_output()
369
370 class Series(PatchSet):
371     """Class including the operations on series
372     """
373     def __init__(self, name = None):
374         """Takes a series name as the parameter.
375         """
376         PatchSet.__init__(self, name)
377
378         # Update the branch to the latest format version if it is
379         # initialized, but don't touch it if it isn't.
380         self.update_to_current_format_version()
381
382         self.__refs_base = 'refs/patches/%s' % self.get_name()
383
384         self.__applied_file = os.path.join(self._dir(), 'applied')
385         self.__unapplied_file = os.path.join(self._dir(), 'unapplied')
386         self.__hidden_file = os.path.join(self._dir(), 'hidden')
387
388         # where this series keeps its patches
389         self.__patch_dir = os.path.join(self._dir(), 'patches')
390
391         # trash directory
392         self.__trash_dir = os.path.join(self._dir(), 'trash')
393
394     def format_version_key(self):
395         return 'branch.%s.stgit.stackformatversion' % self.get_name()
396
397     def update_to_current_format_version(self):
398         """Update a potentially older StGIT directory structure to the
399         latest version. Note: This function should depend as little as
400         possible on external functions that may change during a format
401         version bump, since it must remain able to process older formats."""
402
403         branch_dir = os.path.join(self._basedir(), 'patches', self.get_name())
404         def get_format_version():
405             """Return the integer format version number, or None if the
406             branch doesn't have any StGIT metadata at all, of any version."""
407             fv = config.get(self.format_version_key())
408             ofv = config.get('branch.%s.stgitformatversion' % self.get_name())
409             if fv:
410                 # Great, there's an explicitly recorded format version
411                 # number, which means that the branch is initialized and
412                 # of that exact version.
413                 return int(fv)
414             elif ofv:
415                 # Old name for the version info, upgrade it
416                 config.set(self.format_version_key(), ofv)
417                 config.unset('branch.%s.stgitformatversion' % self.get_name())
418                 return int(ofv)
419             elif os.path.isdir(os.path.join(branch_dir, 'patches')):
420                 # There's a .git/patches/<branch>/patches dirctory, which
421                 # means this is an initialized version 1 branch.
422                 return 1
423             elif os.path.isdir(branch_dir):
424                 # There's a .git/patches/<branch> directory, which means
425                 # this is an initialized version 0 branch.
426                 return 0
427             else:
428                 # The branch doesn't seem to be initialized at all.
429                 return None
430         def set_format_version(v):
431             out.info('Upgraded branch %s to format version %d' % (self.get_name(), v))
432             config.set(self.format_version_key(), '%d' % v)
433         def mkdir(d):
434             if not os.path.isdir(d):
435                 os.makedirs(d)
436         def rm(f):
437             if os.path.exists(f):
438                 os.remove(f)
439         def rm_ref(ref):
440             if git.ref_exists(ref):
441                 git.delete_ref(ref)
442
443         # Update 0 -> 1.
444         if get_format_version() == 0:
445             mkdir(os.path.join(branch_dir, 'trash'))
446             patch_dir = os.path.join(branch_dir, 'patches')
447             mkdir(patch_dir)
448             refs_base = 'refs/patches/%s' % self.get_name()
449             for patch in (file(os.path.join(branch_dir, 'unapplied')).readlines()
450                           + file(os.path.join(branch_dir, 'applied')).readlines()):
451                 patch = patch.strip()
452                 os.rename(os.path.join(branch_dir, patch),
453                           os.path.join(patch_dir, patch))
454                 Patch(patch, patch_dir, refs_base).update_top_ref()
455             set_format_version(1)
456
457         # Update 1 -> 2.
458         if get_format_version() == 1:
459             desc_file = os.path.join(branch_dir, 'description')
460             if os.path.isfile(desc_file):
461                 desc = read_string(desc_file)
462                 if desc:
463                     config.set('branch.%s.description' % self.get_name(), desc)
464                 rm(desc_file)
465             rm(os.path.join(branch_dir, 'current'))
466             rm_ref('refs/bases/%s' % self.get_name())
467             set_format_version(2)
468
469         # Make sure we're at the latest version.
470         if not get_format_version() in [None, FORMAT_VERSION]:
471             raise StackException('Branch %s is at format version %d, expected %d'
472                                  % (self.get_name(), get_format_version(), FORMAT_VERSION))
473
474     def __patch_name_valid(self, name):
475         """Raise an exception if the patch name is not valid.
476         """
477         if not name or re.search('[^\w.-]', name):
478             raise StackException, 'Invalid patch name: "%s"' % name
479
480     def get_patch(self, name):
481         """Return a Patch object for the given name
482         """
483         return Patch(name, self.__patch_dir, self.__refs_base)
484
485     def get_current_patch(self):
486         """Return a Patch object representing the topmost patch, or
487         None if there is no such patch."""
488         crt = self.get_current()
489         if not crt:
490             return None
491         return self.get_patch(crt)
492
493     def get_current(self):
494         """Return the name of the topmost patch, or None if there is
495         no such patch."""
496         try:
497             applied = self.get_applied()
498         except StackException:
499             # No "applied" file: branch is not initialized.
500             return None
501         try:
502             return applied[-1]
503         except IndexError:
504             # No patches applied.
505             return None
506
507     def get_applied(self):
508         if not os.path.isfile(self.__applied_file):
509             raise StackException, 'Branch "%s" not initialised' % self.get_name()
510         return read_strings(self.__applied_file)
511
512     def get_unapplied(self):
513         if not os.path.isfile(self.__unapplied_file):
514             raise StackException, 'Branch "%s" not initialised' % self.get_name()
515         return read_strings(self.__unapplied_file)
516
517     def get_hidden(self):
518         if not os.path.isfile(self.__hidden_file):
519             return []
520         return read_strings(self.__hidden_file)
521
522     def get_base(self):
523         # Return the parent of the bottommost patch, if there is one.
524         if os.path.isfile(self.__applied_file):
525             bottommost = file(self.__applied_file).readline().strip()
526             if bottommost:
527                 return self.get_patch(bottommost).get_bottom()
528         # No bottommost patch, so just return HEAD
529         return git.get_head()
530
531     def get_parent_remote(self):
532         value = config.get('branch.%s.remote' % self.get_name())
533         if value:
534             return value
535         elif 'origin' in git.remotes_list():
536             out.note(('No parent remote declared for stack "%s",'
537                       ' defaulting to "origin".' % self.get_name()),
538                      ('Consider setting "branch.%s.remote" and'
539                       ' "branch.%s.merge" with "git config".'
540                       % (self.get_name(), self.get_name())))
541             return 'origin'
542         else:
543             raise StackException, 'Cannot find a parent remote for "%s"' % self.get_name()
544
545     def __set_parent_remote(self, remote):
546         value = config.set('branch.%s.remote' % self.get_name(), remote)
547
548     def get_parent_branch(self):
549         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
550         if value:
551             return value
552         elif git.rev_parse('heads/origin'):
553             out.note(('No parent branch declared for stack "%s",'
554                       ' defaulting to "heads/origin".' % self.get_name()),
555                      ('Consider setting "branch.%s.stgit.parentbranch"'
556                       ' with "git config".' % self.get_name()))
557             return 'heads/origin'
558         else:
559             raise StackException, 'Cannot find a parent branch for "%s"' % self.get_name()
560
561     def __set_parent_branch(self, name):
562         if config.get('branch.%s.remote' % self.get_name()):
563             # Never set merge if remote is not set to avoid
564             # possibly-erroneous lookups into 'origin'
565             config.set('branch.%s.merge' % self.get_name(), name)
566         config.set('branch.%s.stgit.parentbranch' % self.get_name(), name)
567
568     def set_parent(self, remote, localbranch):
569         if localbranch:
570             if remote:
571                 self.__set_parent_remote(remote)
572             self.__set_parent_branch(localbranch)
573         # We'll enforce this later
574 #         else:
575 #             raise StackException, 'Parent branch (%s) should be specified for %s' % localbranch, self.get_name()
576
577     def __patch_is_current(self, patch):
578         return patch.get_name() == self.get_current()
579
580     def patch_applied(self, name):
581         """Return true if the patch exists in the applied list
582         """
583         return name in self.get_applied()
584
585     def patch_unapplied(self, name):
586         """Return true if the patch exists in the unapplied list
587         """
588         return name in self.get_unapplied()
589
590     def patch_hidden(self, name):
591         """Return true if the patch is hidden.
592         """
593         return name in self.get_hidden()
594
595     def patch_exists(self, name):
596         """Return true if there is a patch with the given name, false
597         otherwise."""
598         return self.patch_applied(name) or self.patch_unapplied(name) \
599                or self.patch_hidden(name)
600
601     def init(self, create_at=False, parent_remote=None, parent_branch=None):
602         """Initialises the stgit series
603         """
604         if self.is_initialised():
605             raise StackException, '%s already initialized' % self.get_name()
606         for d in [self._dir()]:
607             if os.path.exists(d):
608                 raise StackException, '%s already exists' % d
609
610         if (create_at!=False):
611             git.create_branch(self.get_name(), create_at)
612
613         os.makedirs(self.__patch_dir)
614
615         self.set_parent(parent_remote, parent_branch)
616
617         self.create_empty_field('applied')
618         self.create_empty_field('unapplied')
619         self._set_field('orig-base', git.get_head())
620
621         config.set(self.format_version_key(), str(FORMAT_VERSION))
622
623     def rename(self, to_name):
624         """Renames a series
625         """
626         to_stack = Series(to_name)
627
628         if to_stack.is_initialised():
629             raise StackException, '"%s" already exists' % to_stack.get_name()
630
631         patches = self.get_applied() + self.get_unapplied()
632
633         git.rename_branch(self.get_name(), to_name)
634
635         for patch in patches:
636             git.rename_ref('refs/patches/%s/%s' % (self.get_name(), patch),
637                            'refs/patches/%s/%s' % (to_name, patch))
638             git.rename_ref('refs/patches/%s/%s.log' % (self.get_name(), patch),
639                            'refs/patches/%s/%s.log' % (to_name, patch))
640         if os.path.isdir(self._dir()):
641             rename(os.path.join(self._basedir(), 'patches'),
642                    self.get_name(), to_stack.get_name())
643
644         # Rename the config section
645         for k in ['branch.%s', 'branch.%s.stgit']:
646             config.rename_section(k % self.get_name(), k % to_name)
647
648         self.__init__(to_name)
649
650     def clone(self, target_series):
651         """Clones a series
652         """
653         try:
654             # allow cloning of branches not under StGIT control
655             base = self.get_base()
656         except:
657             base = git.get_head()
658         Series(target_series).init(create_at = base)
659         new_series = Series(target_series)
660
661         # generate an artificial description file
662         new_series.set_description('clone of "%s"' % self.get_name())
663
664         # clone self's entire series as unapplied patches
665         try:
666             # allow cloning of branches not under StGIT control
667             applied = self.get_applied()
668             unapplied = self.get_unapplied()
669             patches = applied + unapplied
670             patches.reverse()
671         except:
672             patches = applied = unapplied = []
673         for p in patches:
674             patch = self.get_patch(p)
675             newpatch = new_series.new_patch(p, message = patch.get_description(),
676                                             can_edit = False, unapplied = True,
677                                             bottom = patch.get_bottom(),
678                                             top = patch.get_top(),
679                                             author_name = patch.get_authname(),
680                                             author_email = patch.get_authemail(),
681                                             author_date = patch.get_authdate())
682             if patch.get_log():
683                 out.info('Setting log to %s' %  patch.get_log())
684                 newpatch.set_log(patch.get_log())
685             else:
686                 out.info('No log for %s' % p)
687
688         # fast forward the cloned series to self's top
689         new_series.forward_patches(applied)
690
691         # Clone parent informations
692         value = config.get('branch.%s.remote' % self.get_name())
693         if value:
694             config.set('branch.%s.remote' % target_series, value)
695
696         value = config.get('branch.%s.merge' % self.get_name())
697         if value:
698             config.set('branch.%s.merge' % target_series, value)
699
700         value = config.get('branch.%s.stgit.parentbranch' % self.get_name())
701         if value:
702             config.set('branch.%s.stgit.parentbranch' % target_series, value)
703
704     def delete(self, force = False):
705         """Deletes an stgit series
706         """
707         if self.is_initialised():
708             patches = self.get_unapplied() + self.get_applied()
709             if not force and patches:
710                 raise StackException, \
711                       'Cannot delete: the series still contains patches'
712             for p in patches:
713                 self.get_patch(p).delete()
714
715             # remove the trash directory if any
716             if os.path.exists(self.__trash_dir):
717                 for fname in os.listdir(self.__trash_dir):
718                     os.remove(os.path.join(self.__trash_dir, fname))
719                 os.rmdir(self.__trash_dir)
720
721             # FIXME: find a way to get rid of those manual removals
722             # (move functionality to StgitObject ?)
723             if os.path.exists(self.__applied_file):
724                 os.remove(self.__applied_file)
725             if os.path.exists(self.__unapplied_file):
726                 os.remove(self.__unapplied_file)
727             if os.path.exists(self.__hidden_file):
728                 os.remove(self.__hidden_file)
729             if os.path.exists(self._dir()+'/orig-base'):
730                 os.remove(self._dir()+'/orig-base')
731
732             if not os.listdir(self.__patch_dir):
733                 os.rmdir(self.__patch_dir)
734             else:
735                 out.warn('Patch directory %s is not empty' % self.__patch_dir)
736
737             try:
738                 os.removedirs(self._dir())
739             except OSError:
740                 raise StackException('Series directory %s is not empty'
741                                      % self._dir())
742
743             try:
744                 git.delete_branch(self.get_name())
745             except GitException:
746                 out.warn('Could not delete branch "%s"' % self.get_name())
747
748         config.remove_section('branch.%s' % self.get_name())
749         config.remove_section('branch.%s.stgit' % self.get_name())
750
751     def refresh_patch(self, files = None, message = None, edit = False,
752                       show_patch = False,
753                       cache_update = True,
754                       author_name = None, author_email = None,
755                       author_date = None,
756                       committer_name = None, committer_email = None,
757                       backup = False, sign_str = None, log = 'refresh',
758                       notes = None, bottom = None):
759         """Generates a new commit for the topmost patch
760         """
761         patch = self.get_current_patch()
762         if not patch:
763             raise StackException, 'No patches applied'
764
765         descr = patch.get_description()
766         if not (message or descr):
767             edit = True
768             descr = ''
769         elif message:
770             descr = message
771
772         # TODO: move this out of the stgit.stack module, it is really
773         # for higher level commands to handle the user interaction
774         if not message and edit:
775             descr = edit_file(self, descr.rstrip(), \
776                               'Please edit the description for patch "%s" ' \
777                               'above.' % patch.get_name(), show_patch)
778
779         if not author_name:
780             author_name = patch.get_authname()
781         if not author_email:
782             author_email = patch.get_authemail()
783         if not author_date:
784             author_date = patch.get_authdate()
785         if not committer_name:
786             committer_name = patch.get_commname()
787         if not committer_email:
788             committer_email = patch.get_commemail()
789
790         descr = add_sign_line(descr, sign_str, committer_name, committer_email)
791
792         if not bottom:
793             bottom = patch.get_bottom()
794
795         commit_id = git.commit(files = files,
796                                message = descr, parents = [bottom],
797                                cache_update = cache_update,
798                                allowempty = True,
799                                author_name = author_name,
800                                author_email = author_email,
801                                author_date = author_date,
802                                committer_name = committer_name,
803                                committer_email = committer_email)
804
805         patch.set_bottom(bottom, backup = backup)
806         patch.set_top(commit_id, backup = backup)
807         patch.set_description(descr)
808         patch.set_authname(author_name)
809         patch.set_authemail(author_email)
810         patch.set_authdate(author_date)
811         patch.set_commname(committer_name)
812         patch.set_commemail(committer_email)
813
814         if log:
815             self.log_patch(patch, log, notes)
816
817         return commit_id
818
819     def undo_refresh(self):
820         """Undo the patch boundaries changes caused by 'refresh'
821         """
822         name = self.get_current()
823         assert(name)
824
825         patch = self.get_patch(name)
826         old_bottom = patch.get_old_bottom()
827         old_top = patch.get_old_top()
828
829         # the bottom of the patch is not changed by refresh. If the
830         # old_bottom is different, there wasn't any previous 'refresh'
831         # command (probably only a 'push')
832         if old_bottom != patch.get_bottom() or old_top == patch.get_top():
833             raise StackException, 'No undo information available'
834
835         git.reset(tree_id = old_top, check_out = False)
836         if patch.restore_old_boundaries():
837             self.log_patch(patch, 'undo')
838
839     def new_patch(self, name, message = None, can_edit = True,
840                   unapplied = False, show_patch = False,
841                   top = None, bottom = None, commit = True,
842                   author_name = None, author_email = None, author_date = None,
843                   committer_name = None, committer_email = None,
844                   before_existing = False):
845         """Creates a new patch, either pointing to an existing commit object,
846         or by creating a new commit object.
847         """
848
849         assert commit or (top and bottom)
850         assert not before_existing or (top and bottom)
851         assert not (commit and before_existing)
852         assert (top and bottom) or (not top and not bottom)
853         assert not top or (bottom == git.get_commit(top).get_parent())
854
855         if name != None:
856             self.__patch_name_valid(name)
857             if self.patch_exists(name):
858                 raise StackException, 'Patch "%s" already exists' % name
859
860         # TODO: move this out of the stgit.stack module, it is really
861         # for higher level commands to handle the user interaction
862         if not message and can_edit:
863             descr = edit_file(
864                 self, None,
865                 'Please enter the description for the patch above.',
866                 show_patch)
867         else:
868             descr = message
869
870         head = git.get_head()
871
872         if name == None:
873             name = make_patch_name(descr, self.patch_exists)
874
875         patch = self.get_patch(name)
876         patch.create()
877
878         patch.set_description(descr)
879         patch.set_authname(author_name)
880         patch.set_authemail(author_email)
881         patch.set_authdate(author_date)
882         patch.set_commname(committer_name)
883         patch.set_commemail(committer_email)
884
885         if before_existing:
886             insert_string(self.__applied_file, patch.get_name())
887         elif unapplied:
888             patches = [patch.get_name()] + self.get_unapplied()
889             write_strings(self.__unapplied_file, patches)
890             set_head = False
891         else:
892             append_string(self.__applied_file, patch.get_name())
893             set_head = True
894
895         if commit:
896             if top:
897                 top_commit = git.get_commit(top)
898             else:
899                 bottom = head
900                 top_commit = git.get_commit(head)
901
902             # create a commit for the patch (may be empty if top == bottom);
903             # only commit on top of the current branch
904             assert(unapplied or bottom == head)
905             commit_id = git.commit(message = descr, parents = [bottom],
906                                    cache_update = False,
907                                    tree_id = top_commit.get_tree(),
908                                    allowempty = True, set_head = set_head,
909                                    author_name = author_name,
910                                    author_email = author_email,
911                                    author_date = author_date,
912                                    committer_name = committer_name,
913                                    committer_email = committer_email)
914             # set the patch top to the new commit
915             patch.set_bottom(bottom)
916             patch.set_top(commit_id)
917         else:
918             assert top != bottom
919             patch.set_bottom(bottom)
920             patch.set_top(top)
921
922         self.log_patch(patch, 'new')
923
924         return patch
925
926     def delete_patch(self, name):
927         """Deletes a patch
928         """
929         self.__patch_name_valid(name)
930         patch = self.get_patch(name)
931
932         if self.__patch_is_current(patch):
933             self.pop_patch(name)
934         elif self.patch_applied(name):
935             raise StackException, 'Cannot remove an applied patch, "%s", ' \
936                   'which is not current' % name
937         elif not name in self.get_unapplied():
938             raise StackException, 'Unknown patch "%s"' % name
939
940         # save the commit id to a trash file
941         write_string(os.path.join(self.__trash_dir, name), patch.get_top())
942
943         patch.delete()
944
945         unapplied = self.get_unapplied()
946         unapplied.remove(name)
947         write_strings(self.__unapplied_file, unapplied)
948
    def forward_patches(self, names):
        """Try to fast-forward an array of patches.

        The patches in 'names' (all of which must be unapplied) are taken in
        order for as long as each one can be placed on the current head
        without a real merge:
          - if the patch's bottom already is the head, only its backup
            boundary information is reset;
          - if the bottom's tree equals the head's tree, a new commit with
            the patch's top tree is created on top of the head and logged
            as 'push(f)'.
        The first patch that needs a real merge stops the loop.

        On return, patches in names[0:returned_value] have been pushed on the
        stack. Apply the rest with push_patch
        """
        unapplied = self.get_unapplied()

        forwarded = 0
        # 'top' is the commit the next patch has to sit on
        top = git.get_head()

        for name in names:
            assert(name in unapplied)

            patch = self.get_patch(name)

            head = top
            bottom = patch.get_bottom()
            top = patch.get_top()

            # top != bottom always since we have a commit for each patch
            if head == bottom:
                # reset the backup information. No logging since the
                # patch hasn't changed
                patch.set_bottom(head, backup = True)
                patch.set_top(top, backup = True)

            else:
                head_tree = git.get_commit(head).get_tree()
                bottom_tree = git.get_commit(bottom).get_tree()
                if head_tree == bottom_tree:
                    # We must just reparent this patch and create a new commit
                    # for it
                    descr = patch.get_description()
                    author_name = patch.get_authname()
                    author_email = patch.get_authemail()
                    author_date = patch.get_authdate()
                    committer_name = patch.get_commname()
                    committer_email = patch.get_commemail()

                    top_tree = git.get_commit(top).get_tree()

                    top = git.commit(message = descr, parents = [head],
                                     cache_update = False,
                                     tree_id = top_tree,
                                     allowempty = True,
                                     author_name = author_name,
                                     author_email = author_email,
                                     author_date = author_date,
                                     committer_name = committer_name,
                                     committer_email = committer_email)

                    patch.set_bottom(head, backup = True)
                    patch.set_top(top, backup = True)

                    self.log_patch(patch, 'push(f)')
                else:
                    # roll 'top' back to the last forwarded commit so the
                    # git.switch() below lands on it
                    top = head
                    # stop the fast-forwarding, must do a real merge
                    break

            forwarded+=1
            unapplied.remove(name)

        if forwarded == 0:
            return 0

        # move to the top of the last forwarded patch in one step
        git.switch(top)

        append_strings(self.__applied_file, names[0:forwarded])
        write_strings(self.__unapplied_file, unapplied)

        return forwarded
1022
1023     def merged_patches(self, names):
1024         """Test which patches were merged upstream by reverse-applying
1025         them in reverse order. The function returns the list of
1026         patches detected to have been applied. The state of the tree
1027         is restored to the original one
1028         """
1029         patches = [self.get_patch(name) for name in names]
1030         patches.reverse()
1031
1032         merged = []
1033         for p in patches:
1034             if git.apply_diff(p.get_top(), p.get_bottom()):
1035                 merged.append(p.get_name())
1036         merged.reverse()
1037
1038         git.reset()
1039
1040         return merged
1041
1042     def push_empty_patch(self, name):
1043         """Pushes an empty patch on the stack
1044         """
1045         unapplied = self.get_unapplied()
1046         assert(name in unapplied)
1047
1048         # patch = self.get_patch(name)
1049         head = git.get_head()
1050
1051         append_string(self.__applied_file, name)
1052
1053         unapplied.remove(name)
1054         write_strings(self.__unapplied_file, unapplied)
1055
1056         self.refresh_patch(bottom = head, cache_update = False, log = 'push(m)')
1057
    def push_patch(self, name):
        """Pushes a patch on the stack

        Returns True if the patch needed a merge (git.apply_diff() failed
        and git.merge() was used), False otherwise.  If git.merge() raises
        git.GitException, the patch is still recorded as applied and
        refreshed with a 'push(c)' log entry, and then a StackException
        carrying the merge error is raised.
        """
        unapplied = self.get_unapplied()
        assert(name in unapplied)

        patch = self.get_patch(name)

        head = git.get_head()
        bottom = patch.get_bottom()
        top = patch.get_top()
        # top != bottom always since we have a commit for each patch

        if head == bottom:
            # A fast-forward push. Just reset the backup
            # information. No need for logging
            patch.set_bottom(bottom, backup = True)
            patch.set_top(top, backup = True)

            git.switch(top)
            append_string(self.__applied_file, name)

            unapplied.remove(name)
            write_strings(self.__unapplied_file, unapplied)
            return False

        # Need to create a new commit and merge in the old patch
        # 'ex' stays None unless the merge fails; with Python 2 'except'
        # semantics the bound exception remains visible after the handler
        ex = None
        modified = False

        # Try the fast applying first. If this fails, fall back to the
        # three-way merge
        if not git.apply_diff(bottom, top):
            # if git.apply_diff() fails, the patch requires a diff3
            # merge and can be reported as modified
            modified = True

            # merge can fail but the patch needs to be pushed
            try:
                git.merge(bottom, head, top, recursive = True)
            except git.GitException, ex:
                out.error('The merge failed during "push".',
                          'Use "refresh" after fixing the conflicts or'
                          ' revert the operation with "push --undo".')

        append_string(self.__applied_file, name)

        unapplied.remove(name)
        write_strings(self.__unapplied_file, unapplied)

        if not ex:
            # if the merge was OK and no conflicts, just refresh the patch
            # The GIT cache was already updated by the merge operation
            if modified:
                log = 'push(m)'
            else:
                log = 'push'
            self.refresh_patch(bottom = head, cache_update = False, log = log)
        else:
            # we store the correctly merged files only for
            # tracking the conflict history. Note that the
            # git.merge() operations should always leave the index
            # in a valid state (i.e. only stage 0 files)
            self.refresh_patch(bottom = head, cache_update = False,
                               log = 'push(c)')
            raise StackException, str(ex)

        return modified
1126
1127     def undo_push(self):
1128         name = self.get_current()
1129         assert(name)
1130
1131         patch = self.get_patch(name)
1132         old_bottom = patch.get_old_bottom()
1133         old_top = patch.get_old_top()
1134
1135         # the top of the patch is changed by a push operation only
1136         # together with the bottom (otherwise the top was probably
1137         # modified by 'refresh'). If they are both unchanged, there
1138         # was a fast forward
1139         if old_bottom == patch.get_bottom() and old_top != patch.get_top():
1140             raise StackException, 'No undo information available'
1141
1142         git.reset()
1143         self.pop_patch(name)
1144         ret = patch.restore_old_boundaries()
1145         if ret:
1146             self.log_patch(patch, 'undo')
1147
1148         return ret
1149
1150     def pop_patch(self, name, keep = False):
1151         """Pops the top patch from the stack
1152         """
1153         applied = self.get_applied()
1154         applied.reverse()
1155         assert(name in applied)
1156
1157         patch = self.get_patch(name)
1158
1159         if git.get_head_file() == self.get_name():
1160             if keep and not git.apply_diff(git.get_head(), patch.get_bottom()):
1161                 raise StackException(
1162                     'Failed to pop patches while preserving the local changes')
1163             git.switch(patch.get_bottom(), keep)
1164         else:
1165             git.set_branch(self.get_name(), patch.get_bottom())
1166
1167         # save the new applied list
1168         idx = applied.index(name) + 1
1169
1170         popped = applied[:idx]
1171         popped.reverse()
1172         unapplied = popped + self.get_unapplied()
1173         write_strings(self.__unapplied_file, unapplied)
1174
1175         del applied[:idx]
1176         applied.reverse()
1177         write_strings(self.__applied_file, applied)
1178
1179     def empty_patch(self, name):
1180         """Returns True if the patch is empty
1181         """
1182         self.__patch_name_valid(name)
1183         patch = self.get_patch(name)
1184         bottom = patch.get_bottom()
1185         top = patch.get_top()
1186
1187         if bottom == top:
1188             return True
1189         elif git.get_commit(top).get_tree() \
1190                  == git.get_commit(bottom).get_tree():
1191             return True
1192
1193         return False
1194
1195     def rename_patch(self, oldname, newname):
1196         self.__patch_name_valid(newname)
1197
1198         applied = self.get_applied()
1199         unapplied = self.get_unapplied()
1200
1201         if oldname == newname:
1202             raise StackException, '"To" name and "from" name are the same'
1203
1204         if newname in applied or newname in unapplied:
1205             raise StackException, 'Patch "%s" already exists' % newname
1206
1207         if oldname in unapplied:
1208             self.get_patch(oldname).rename(newname)
1209             unapplied[unapplied.index(oldname)] = newname
1210             write_strings(self.__unapplied_file, unapplied)
1211         elif oldname in applied:
1212             self.get_patch(oldname).rename(newname)
1213
1214             applied[applied.index(oldname)] = newname
1215             write_strings(self.__applied_file, applied)
1216         else:
1217             raise StackException, 'Unknown patch "%s"' % oldname
1218
1219     def log_patch(self, patch, message, notes = None):
1220         """Generate a log commit for a patch
1221         """
1222         top = git.get_commit(patch.get_top())
1223         old_log = patch.get_log()
1224
1225         if message is None:
1226             # replace the current log entry
1227             if not old_log:
1228                 raise StackException, \
1229                       'No log entry to annotate for patch "%s"' \
1230                       % patch.get_name()
1231             replace = True
1232             log_commit = git.get_commit(old_log)
1233             msg = log_commit.get_log().split('\n')[0]
1234             log_parent = log_commit.get_parent()
1235             if log_parent:
1236                 parents = [log_parent]
1237             else:
1238                 parents = []
1239         else:
1240             # generate a new log entry
1241             replace = False
1242             msg = '%s\t%s' % (message, top.get_id_hash())
1243             if old_log:
1244                 parents = [old_log]
1245             else:
1246                 parents = []
1247
1248         if notes:
1249             msg += '\n\n' + notes
1250
1251         log = git.commit(message = msg, parents = parents,
1252                          cache_update = False, tree_id = top.get_tree(),
1253                          allowempty = True)
1254         patch.set_log(log)
1255
1256     def hide_patch(self, name):
1257         """Add the patch to the hidden list.
1258         """
1259         unapplied = self.get_unapplied()
1260         if name not in unapplied:
1261             # keep the checking order for backward compatibility with
1262             # the old hidden patches functionality
1263             if self.patch_applied(name):
1264                 raise StackException, 'Cannot hide applied patch "%s"' % name
1265             elif self.patch_hidden(name):
1266                 raise StackException, 'Patch "%s" already hidden' % name
1267             else:
1268                 raise StackException, 'Unknown patch "%s"' % name
1269
1270         if not self.patch_hidden(name):
1271             # check needed for backward compatibility with the old
1272             # hidden patches functionality
1273             append_string(self.__hidden_file, name)
1274
1275         unapplied.remove(name)
1276         write_strings(self.__unapplied_file, unapplied)
1277
    def unhide_patch(self, name):
        """Remove the patch from the hidden list.

        The name is also appended to the unapplied list unless the patch is
        already applied or unapplied.  Raises StackException if the patch is
        not hidden; the message distinguishes a known but visible patch from
        an unknown name.
        """
        hidden = self.get_hidden()
        if not name in hidden:
            if self.patch_applied(name) or self.patch_unapplied(name):
                raise StackException, 'Patch "%s" not hidden' % name
            else:
                raise StackException, 'Unknown patch "%s"' % name

        hidden.remove(name)
        write_strings(self.__hidden_file, hidden)

        if not self.patch_applied(name) and not self.patch_unapplied(name):
            # check needed for backward compatibility with the old
            # hidden patches functionality
            append_string(self.__unapplied_file, name)