chiark / gitweb /
2994eb11a726da75f5e3d60e5b6d7c2289a88ee6
[stgit] / stgit / commands / common.py
1 """Function/variables common to all the commands
2 """
3
4 __copyright__ = """
5 Copyright (C) 2005, Catalin Marinas <catalin.marinas@gmail.com>
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License version 2 as
9 published by the Free Software Foundation.
10
11 This program is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with this program; if not, write to the Free Software
18 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
19 """
20
21 import sys, os, os.path, re
22 from optparse import OptionParser, make_option
23
24 from stgit.exception import *
25 from stgit.utils import *
26 from stgit.out import *
27 from stgit.run import *
28 from stgit import stack, git, basedir
29 from stgit.config import config, file_extensions
30 from stgit.lib import stack as libstack
31 from stgit.lib import git as libgit
32
33 # Command exception class
34 class CmdException(StgException):
35     pass
36
37 # Utility functions
38 class RevParseException(StgException):
39     """Revision spec parse error."""
40     pass
41
42 def parse_rev(rev):
43     """Parse a revision specification into its
44     patchname@branchname//patch_id parts. If no branch name has a slash
45     in it, also accept / instead of //."""
46     if '/' in ''.join(git.get_heads()):
47         # We have branch names with / in them.
48         branch_chars = r'[^@]'
49         patch_id_mark = r'//'
50     else:
51         # No / in branch names.
52         branch_chars = r'[^@/]'
53         patch_id_mark = r'(/|//)'
54     patch_re = r'(?P<patch>[^@/]+)'
55     branch_re = r'@(?P<branch>%s+)' % branch_chars
56     patch_id_re = r'%s(?P<patch_id>[a-z.]*)' % patch_id_mark
57
58     # Try //patch_id.
59     m = re.match(r'^%s$' % patch_id_re, rev)
60     if m:
61         return None, None, m.group('patch_id')
62
63     # Try path[@branch]//patch_id.
64     m = re.match(r'^%s(%s)?%s$' % (patch_re, branch_re, patch_id_re), rev)
65     if m:
66         return m.group('patch'), m.group('branch'), m.group('patch_id')
67
68     # Try patch[@branch].
69     m = re.match(r'^%s(%s)?$' % (patch_re, branch_re), rev)
70     if m:
71         return m.group('patch'), m.group('branch'), None
72
73     # No, we can't parse that.
74     raise RevParseException
75
76 def git_id(crt_series, rev):
77     """Return the GIT id
78     """
79     if not rev:
80         return None
81
82     # try a GIT revision first
83     try:
84         return git.rev_parse(rev + '^{commit}')
85     except git.GitException:
86         pass
87
88     # try an StGIT patch name
89     try:
90         patch, branch, patch_id = parse_rev(rev)
91         if branch == None:
92             series = crt_series
93         else:
94             series = stack.Series(branch)
95         if patch == None:
96             patch = series.get_current()
97             if not patch:
98                 raise CmdException, 'No patches applied'
99         if patch in series.get_applied() or patch in series.get_unapplied() or \
100                patch in series.get_hidden():
101             if patch_id in ['top', '', None]:
102                 return series.get_patch(patch).get_top()
103             elif patch_id == 'bottom':
104                 return series.get_patch(patch).get_bottom()
105             elif patch_id == 'top.old':
106                 return series.get_patch(patch).get_old_top()
107             elif patch_id == 'bottom.old':
108                 return series.get_patch(patch).get_old_bottom()
109             elif patch_id == 'log':
110                 return series.get_patch(patch).get_log()
111         if patch == 'base' and patch_id == None:
112             return series.get_base()
113     except RevParseException:
114         pass
115     except stack.StackException:
116         pass
117
118     raise CmdException, 'Unknown patch or revision: %s' % rev
119
120 def git_commit(name, repository, branch = None):
121     """Return the a Commit object if 'name' is a patch name or Git commit.
122     The patch names allowed are in the form '<branch>:<patch>' and can
123     be followed by standard symbols used by git-rev-parse. If <patch>
124     is '{base}', it represents the bottom of the stack.
125     """
126     # Try a [branch:]patch name first
127     try:
128         branch, patch = name.split(':', 1)
129     except ValueError:
130         patch = name
131     if not branch:
132         branch = repository.current_branch_name
133
134     # The stack base
135     if patch.startswith('{base}'):
136         base_id = repository.get_stack(branch).base.sha1
137         return repository.rev_parse(base_id +
138                                     strip_prefix('{base}', patch))
139
140     # Other combination of branch and patch
141     try:
142         return repository.rev_parse('patches/%s/%s' % (branch, patch),
143                                     discard_stderr = True)
144     except libgit.RepositoryException:
145         pass
146
147     # Try a Git commit
148     try:
149         return repository.rev_parse(name, discard_stderr = True)
150     except libgit.RepositoryException:
151         raise CmdException('%s: Unknown patch or revision name' % name)
152
153 def check_local_changes():
154     if git.local_changes():
155         raise CmdException('local changes in the tree. Use "refresh" or'
156                            ' "status --reset"')
157
158 def check_head_top_equal(crt_series):
159     if not crt_series.head_top_equal():
160         raise CmdException('HEAD and top are not the same. This can happen'
161                            ' if you modify a branch with git. "stg repair'
162                            ' --help" explains more about what to do next.')
163
164 def check_conflicts():
165     if git.get_conflicts():
166         raise CmdException('Unsolved conflicts. Please resolve them first'
167                            ' or revert the changes with "status --reset"')
168
169 def print_crt_patch(crt_series, branch = None):
170     if not branch:
171         patch = crt_series.get_current()
172     else:
173         patch = stack.Series(branch).get_current()
174
175     if patch:
176         out.info('Now at patch "%s"' % patch)
177     else:
178         out.info('No patches applied')
179
180 def resolved_all(reset = None):
181     conflicts = git.get_conflicts()
182     git.resolved(conflicts, reset)
183
184 def push_patches(crt_series, patches, check_merged = False):
185     """Push multiple patches onto the stack. This function is shared
186     between the push and pull commands
187     """
188     forwarded = crt_series.forward_patches(patches)
189     if forwarded > 1:
190         out.info('Fast-forwarded patches "%s" - "%s"'
191                  % (patches[0], patches[forwarded - 1]))
192     elif forwarded == 1:
193         out.info('Fast-forwarded patch "%s"' % patches[0])
194
195     names = patches[forwarded:]
196
197     # check for patches merged upstream
198     if names and check_merged:
199         out.start('Checking for patches merged upstream')
200
201         merged = crt_series.merged_patches(names)
202
203         out.done('%d found' % len(merged))
204     else:
205         merged = []
206
207     for p in names:
208         out.start('Pushing patch "%s"' % p)
209
210         if p in merged:
211             crt_series.push_empty_patch(p)
212             out.done('merged upstream')
213         else:
214             modified = crt_series.push_patch(p)
215
216             if crt_series.empty_patch(p):
217                 out.done('empty patch')
218             elif modified:
219                 out.done('modified')
220             else:
221                 out.done()
222
223 def pop_patches(crt_series, patches, keep = False):
224     """Pop the patches in the list from the stack. It is assumed that
225     the patches are listed in the stack reverse order.
226     """
227     if len(patches) == 0:
228         out.info('Nothing to push/pop')
229     else:
230         p = patches[-1]
231         if len(patches) == 1:
232             out.start('Popping patch "%s"' % p)
233         else:
234             out.start('Popping patches "%s" - "%s"' % (patches[0], p))
235         crt_series.pop_patch(p, keep)
236         out.done()
237
238 def parse_patches(patch_args, patch_list, boundary = 0, ordered = False):
239     """Parse patch_args list for patch names in patch_list and return
240     a list. The names can be individual patches and/or in the
241     patch1..patch2 format.
242     """
243     # in case it receives a tuple
244     patch_list = list(patch_list)
245     patches = []
246
247     for name in patch_args:
248         pair = name.split('..')
249         for p in pair:
250             if p and not p in patch_list:
251                 raise CmdException, 'Unknown patch name: %s' % p
252
253         if len(pair) == 1:
254             # single patch name
255             pl = pair
256         elif len(pair) == 2:
257             # patch range [p1]..[p2]
258             # inclusive boundary
259             if pair[0]:
260                 first = patch_list.index(pair[0])
261             else:
262                 first = -1
263             # exclusive boundary
264             if pair[1]:
265                 last = patch_list.index(pair[1]) + 1
266             else:
267                 last = -1
268
269             # only cross the boundary if explicitly asked
270             if not boundary:
271                 boundary = len(patch_list)
272             if first < 0:
273                 if last <= boundary:
274                     first = 0
275                 else:
276                     first = boundary
277             if last < 0:
278                 if first < boundary:
279                     last = boundary
280                 else:
281                     last = len(patch_list)
282
283             if last > first:
284                 pl = patch_list[first:last]
285             else:
286                 pl = patch_list[(last - 1):(first + 1)]
287                 pl.reverse()
288         else:
289             raise CmdException, 'Malformed patch name: %s' % name
290
291         for p in pl:
292             if p in patches:
293                 raise CmdException, 'Duplicate patch name: %s' % p
294
295         patches += pl
296
297     if ordered:
298         patches = [p for p in patch_list if p in patches]
299
300     return patches
301
302 def name_email(address):
303     p = parse_name_email(address)
304     if p:
305         return p
306     else:
307         raise CmdException('Incorrect "name <email>"/"email (name)" string: %s'
308                            % address)
309
310 def name_email_date(address):
311     p = parse_name_email_date(address)
312     if p:
313         return p
314     else:
315         raise CmdException('Incorrect "name <email> date" string: %s' % address)
316
317 def address_or_alias(addr_str):
318     """Return the address if it contains an e-mail address or look up
319     the aliases in the config files.
320     """
321     def __address_or_alias(addr):
322         if not addr:
323             return None
324         if addr.find('@') >= 0:
325             # it's an e-mail address
326             return addr
327         alias = config.get('mail.alias.'+addr)
328         if alias:
329             # it's an alias
330             return alias
331         raise CmdException, 'unknown e-mail alias: %s' % addr
332
333     addr_list = [__address_or_alias(addr.strip())
334                  for addr in addr_str.split(',')]
335     return ', '.join([addr for addr in addr_list if addr])
336
337 def prepare_rebase(crt_series):
338     # pop all patches
339     applied = crt_series.get_applied()
340     if len(applied) > 0:
341         out.start('Popping all applied patches')
342         crt_series.pop_patch(applied[0])
343         out.done()
344     return applied
345
346 def rebase(crt_series, target):
347     try:
348         tree_id = git_id(crt_series, target)
349     except:
350         # it might be that we use a custom rebase command with its own
351         # target type
352         tree_id = target
353     if tree_id == git.get_head():
354         out.info('Already at "%s", no need for rebasing.' % target)
355         return
356     if target:
357         out.start('Rebasing to "%s"' % target)
358     else:
359         out.start('Rebasing to the default target')
360     git.rebase(tree_id = tree_id)
361     out.done()
362
363 def post_rebase(crt_series, applied, nopush, merged):
364     # memorize that we rebased to here
365     crt_series._set_field('orig-base', git.get_head())
366     # push the patches back
367     if not nopush:
368         push_patches(crt_series, applied, merged)
369
370 #
371 # Patch description/e-mail/diff parsing
372 #
373 def __end_descr(line):
374     return re.match('---\s*$', line) or re.match('diff -', line) or \
375             re.match('Index: ', line)
376
377 def __split_descr_diff(string):
378     """Return the description and the diff from the given string
379     """
380     descr = diff = ''
381     top = True
382
383     for line in string.split('\n'):
384         if top:
385             if not __end_descr(line):
386                 descr += line + '\n'
387                 continue
388             else:
389                 top = False
390         diff += line + '\n'
391
392     return (descr.rstrip(), diff)
393
394 def __parse_description(descr):
395     """Parse the patch description and return the new description and
396     author information (if any).
397     """
398     subject = body = ''
399     authname = authemail = authdate = None
400
401     descr_lines = [line.rstrip() for line in  descr.split('\n')]
402     if not descr_lines:
403         raise CmdException, "Empty patch description"
404
405     lasthdr = 0
406     end = len(descr_lines)
407
408     # Parse the patch header
409     for pos in range(0, end):
410         if not descr_lines[pos]:
411            continue
412         # check for a "From|Author:" line
413         if re.match('\s*(?:from|author):\s+', descr_lines[pos], re.I):
414             auth = re.findall('^.*?:\s+(.*)$', descr_lines[pos])[0]
415             authname, authemail = name_email(auth)
416             lasthdr = pos + 1
417             continue
418         # check for a "Date:" line
419         if re.match('\s*date:\s+', descr_lines[pos], re.I):
420             authdate = re.findall('^.*?:\s+(.*)$', descr_lines[pos])[0]
421             lasthdr = pos + 1
422             continue
423         if subject:
424             break
425         # get the subject
426         subject = descr_lines[pos]
427         lasthdr = pos + 1
428
429     # get the body
430     if lasthdr < end:
431         body = reduce(lambda x, y: x + '\n' + y, descr_lines[lasthdr:], '')
432
433     return (subject + body, authname, authemail, authdate)
434
435 def parse_mail(msg):
436     """Parse the message object and return (description, authname,
437     authemail, authdate, diff)
438     """
439     from email.Header import decode_header, make_header
440
441     def __decode_header(header):
442         """Decode a qp-encoded e-mail header as per rfc2047"""
443         try:
444             words_enc = decode_header(header)
445             hobj = make_header(words_enc)
446         except Exception, ex:
447             raise CmdException, 'header decoding error: %s' % str(ex)
448         return unicode(hobj).encode('utf-8')
449
450     # parse the headers
451     if msg.has_key('from'):
452         authname, authemail = name_email(__decode_header(msg['from']))
453     else:
454         authname = authemail = None
455
456     # '\n\t' can be found on multi-line headers
457     descr = __decode_header(msg['subject']).replace('\n\t', ' ')
458     authdate = msg['date']
459
460     # remove the '[*PATCH*]' expression in the subject
461     if descr:
462         descr = re.findall('^(\[.*?[Pp][Aa][Tt][Cc][Hh].*?\])?\s*(.*)$',
463                            descr)[0][1]
464     else:
465         raise CmdException, 'Subject: line not found'
466
467     # the rest of the message
468     msg_text = ''
469     for part in msg.walk():
470         if part.get_content_type() == 'text/plain':
471             msg_text += part.get_payload(decode = True)
472
473     rem_descr, diff = __split_descr_diff(msg_text)
474     if rem_descr:
475         descr += '\n\n' + rem_descr
476
477     # parse the description for author information
478     descr, descr_authname, descr_authemail, descr_authdate = \
479            __parse_description(descr)
480     if descr_authname:
481         authname = descr_authname
482     if descr_authemail:
483         authemail = descr_authemail
484     if descr_authdate:
485        authdate = descr_authdate
486
487     return (descr, authname, authemail, authdate, diff)
488
489 def parse_patch(text):
490     """Parse the input text and return (description, authname,
491     authemail, authdate, diff)
492     """
493     descr, diff = __split_descr_diff(text)
494     descr, authname, authemail, authdate = __parse_description(descr)
495
496     # we don't yet have an agreed place for the creation date.
497     # Just return None
498     return (descr, authname, authemail, authdate, diff)
499
500 def readonly_constant_property(f):
501     """Decorator that converts a function that computes a value to an
502     attribute that returns the value. The value is computed only once,
503     the first time it is accessed."""
504     def new_f(self):
505         n = '__' + f.__name__
506         if not hasattr(self, n):
507             setattr(self, n, f(self))
508         return getattr(self, n)
509     return property(new_f)
510
511 class DirectoryException(StgException):
512     pass
513
514 class _Directory(object):
515     def __init__(self, needs_current_series = True):
516         self.needs_current_series =  needs_current_series
517     @readonly_constant_property
518     def git_dir(self):
519         try:
520             return Run('git', 'rev-parse', '--git-dir'
521                        ).discard_stderr().output_one_line()
522         except RunException:
523             raise DirectoryException('No git repository found')
524     @readonly_constant_property
525     def __topdir_path(self):
526         try:
527             lines = Run('git', 'rev-parse', '--show-cdup'
528                         ).discard_stderr().output_lines()
529             if len(lines) == 0:
530                 return '.'
531             elif len(lines) == 1:
532                 return lines[0]
533             else:
534                 raise RunException('Too much output')
535         except RunException:
536             raise DirectoryException('No git repository found')
537     @readonly_constant_property
538     def is_inside_git_dir(self):
539         return { 'true': True, 'false': False
540                  }[Run('git', 'rev-parse', '--is-inside-git-dir'
541                        ).output_one_line()]
542     @readonly_constant_property
543     def is_inside_worktree(self):
544         return { 'true': True, 'false': False
545                  }[Run('git', 'rev-parse', '--is-inside-work-tree'
546                        ).output_one_line()]
547     def cd_to_topdir(self):
548         os.chdir(self.__topdir_path)
549
550 class DirectoryAnywhere(_Directory):
551     def setup(self):
552         pass
553
554 class DirectoryHasRepository(_Directory):
555     def setup(self):
556         self.git_dir # might throw an exception
557
558 class DirectoryInWorktree(DirectoryHasRepository):
559     def setup(self):
560         DirectoryHasRepository.setup(self)
561         if not self.is_inside_worktree:
562             raise DirectoryException('Not inside a git worktree')
563
564 class DirectoryGotoToplevel(DirectoryInWorktree):
565     def setup(self):
566         DirectoryInWorktree.setup(self)
567         self.cd_to_topdir()
568
569 class DirectoryHasRepositoryLib(_Directory):
570     """For commands that use the new infrastructure in stgit.lib.*."""
571     def __init__(self):
572         self.needs_current_series = False
573     def setup(self):
574         # This will throw an exception if we don't have a repository.
575         self.repository = libstack.Repository.default()