chiark / gitweb /
Teach "stg import" to import patch series from tar archives
[stgit] / stgit / commands / imprt.py
index de5e9a506ef7f2f416cf096c99d09df4ffceb632..3eb29ba32dc5af41634d79a5a80894caf97a56cf 100644 (file)
@@ -15,7 +15,7 @@ along with this program; if not, write to the Free Software
 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 """
 
-import sys, os, re, email
+import sys, os, re, email, tarfile
 from mailbox import UnixMailbox
 from StringIO import StringIO
 from stgit.argparse import opt
@@ -51,7 +51,8 @@ options = [
     opt('-M', '--mbox', action = 'store_true',
         short = 'Import a series of patches from an mbox file'),
     opt('-s', '--series', action = 'store_true',
-        short = 'Import a series of patches'),
+        short = 'Import a series of patches', long = """
+        Import a series of patches from a series file or a tar archive."""),
     opt('-u', '--url', action = 'store_true',
         short = 'Import a patch from a URL'),
     opt('-n', '--name',
@@ -227,6 +228,9 @@ def __import_series(filename, options):
     applied = crt_series.get_applied()
 
     if filename:
+        if tarfile.is_tarfile(filename):
+            __import_tarfile(filename, options)
+            return
         f = file(filename)
         patchdir = os.path.dirname(filename)
     else:
@@ -280,6 +284,44 @@ def __import_url(url, options):
     urllib.urlretrieve(url, filename)
     __import_file(filename, options)
 
+def __import_tarfile(tar, options):
+    """Import patch series from a tar archive
+    """
+    import tempfile
+    import shutil
+
+    if not tarfile.is_tarfile(tar):
+        raise CmdException, "%s is not a tarfile!" % tar
+
+    t = tarfile.open(tar, 'r')
+    names = t.getnames()
+
+    # verify paths in the tarfile are safe
+    for n in names:
+        if n.startswith('/'):
+            raise CmdException, "Absolute path found in %s" % tar
+        if n.find("..") > -1:
+            raise CmdException, "Relative path found in %s" % tar
+
+    # find the series file
+    seriesfile = '';
+    for m in names:
+        if m.endswith('/series') or m == 'series':
+            seriesfile = m
+            break
+    if seriesfile == '':
+        raise CmdException, "no 'series' file found in %s" % tar
+
+    # unpack into a tmp dir
+    tmpdir = tempfile.mkdtemp('.stg')
+    t.extractall(tmpdir)
+
+    # apply the series
+    __import_series(os.path.join(tmpdir, seriesfile), options)
+
+    # cleanup the tmpdir
+    shutil.rmtree(tmpdir)
+
 def func(parser, options, args):
     """Import a GNU diff file as a new patch
     """