granicus.if.org Git - python/commitdiff
Refactor recently added bugfix into more testable code by using a
author: Gregory P. Smith <greg@krypto.org>
Sun, 3 Feb 2013 08:36:32 +0000 (00:36 -0800)
committer: Gregory P. Smith <greg@krypto.org>
Sun, 3 Feb 2013 08:36:32 +0000 (00:36 -0800)
method for Windows file name sanitization.  Splits the unit test up
into several tests based on platform.

Lib/test/test_zipfile.py
Lib/zipfile.py

index c1e20b286544b534b7219b8db5ec7d482260b429..5e837cd6de667a0de0d19646ccac1ca02316ab3e 100644 (file)
@@ -538,8 +538,15 @@ class TestsWithSourceFile(unittest.TestCase):
         with open(filename, 'rb') as f:
             self.assertEqual(f.read(), content)
 
-    def test_extract_hackers_arcnames(self):
-        hacknames = [
+    def test_sanitize_windows_name(self):
+        san = zipfile.ZipFile._sanitize_windows_name
+        # Passing pathsep in allows this test to work regardless of platform.
+        self.assertEqual(san(r',,?,C:,foo,bar/z', ','), r'_,C_,foo,bar/z')
+        self.assertEqual(san(r'a\b,c<d>e|f"g?h*i', ','), r'a\b,c_d_e_f_g_h_i')
+        self.assertEqual(san('../../foo../../ba..r', '/'), r'foo/ba..r')
+
+    def test_extract_hackers_arcnames_common_cases(self):
+        common_hacknames = [
             ('../foo/bar', 'foo/bar'),
             ('foo/../bar', 'foo/bar'),
             ('foo/../../bar', 'foo/bar'),
@@ -549,8 +556,12 @@ class TestsWithSourceFile(unittest.TestCase):
             ('/foo/../bar', 'foo/bar'),
             ('/foo/../../bar', 'foo/bar'),
         ]
-        if os.path.sep == '\\':  # Windows.
-            hacknames.extend([
+        self._test_extract_hackers_arcnames(common_hacknames)
+
+    @unittest.skipIf(os.path.sep != '\\', 'Requires \\ as path separator.')
+    def test_extract_hackers_arcnames_windows_only(self):
+        """Test combination of path fixing and windows name sanitization."""
+        windows_hacknames = [
                 (r'..\foo\bar', 'foo/bar'),
                 (r'..\/foo\/bar', 'foo/bar'),
                 (r'foo/\..\/bar', 'foo/bar'),
@@ -570,14 +581,19 @@ class TestsWithSourceFile(unittest.TestCase):
                 (r'C:/../C:/foo/bar', 'C_/foo/bar'),
                 (r'a:b\c<d>e|f"g?h*i', 'b/c_d_e_f_g_h_i'),
                 ('../../foo../../ba..r', 'foo/ba..r'),
-            ])
-        else:  # Unix
-            hacknames.extend([
-                ('//foo/bar', 'foo/bar'),
-                ('../../foo../../ba..r', 'foo../ba..r'),
-                (r'foo/..\bar', r'foo/..\bar'),
-            ])
+        ]
+        self._test_extract_hackers_arcnames(windows_hacknames)
+
+    @unittest.skipIf(os.path.sep != '/', r'Requires / as path separator.')
+    def test_extract_hackers_arcnames_posix_only(self):
+        posix_hacknames = [
+            ('//foo/bar', 'foo/bar'),
+            ('../../foo../../ba..r', 'foo../ba..r'),
+            (r'foo/..\bar', r'foo/..\bar'),
+        ]
+        self._test_extract_hackers_arcnames(posix_hacknames)
 
+    def _test_extract_hackers_arcnames(self, hacknames):
         for arcname, fixedname in hacknames:
             content = b'foobar' + arcname.encode()
             with zipfile.ZipFile(TESTFN2, 'w', zipfile.ZIP_STORED) as zipfp:
@@ -594,7 +610,8 @@ class TestsWithSourceFile(unittest.TestCase):
             with zipfile.ZipFile(TESTFN2, 'r') as zipfp:
                 writtenfile = zipfp.extract(arcname, targetpath)
                 self.assertEqual(writtenfile, correctfile,
-                                 msg="extract %r" % arcname)
+                                 msg='extract %r: %r != %r' %
+                                 (arcname, writtenfile, correctfile))
             self.check_file(correctfile, content)
             shutil.rmtree('target')
 
index 8b355d6acdbd6d6b5f1c53bbbe7b748d09578d46..3448c61795d7abaabe514e6f9d22494d1d106263 100644 (file)
@@ -883,6 +883,7 @@ class ZipFile:
     """
 
     fp = None                   # Set here since __del__ checks it
+    _windows_illegal_name_trans_table = None
 
     def __init__(self, file, mode="r", compression=ZIP_STORED, allowZip64=False):
         """Open the ZIP file with mode read "r", write "w" or append "a"."""
@@ -1223,6 +1224,21 @@ class ZipFile:
         for zipinfo in members:
             self.extract(zipinfo, path, pwd)
 
+    @classmethod
+    def _sanitize_windows_name(cls, arcname, pathsep):
+        """Replace bad characters and remove trailing dots from parts."""
+        table = cls._windows_illegal_name_trans_table
+        if not table:
+            illegal = ':<>|"?*'
+            table = str.maketrans(illegal, '_' * len(illegal))
+            cls._windows_illegal_name_trans_table = table
+        arcname = arcname.translate(table)
+        # remove trailing dots
+        arcname = (x.rstrip('.') for x in arcname.split(pathsep))
+        # rejoin, removing empty parts.
+        arcname = pathsep.join(x for x in arcname if x)
+        return arcname
+
     def _extract_member(self, member, targetpath, pwd):
         """Extract the ZipInfo object 'member' to a physical
            file on the path targetpath.
@@ -1236,16 +1252,12 @@ class ZipFile:
         # interpret absolute pathname as relative, remove drive letter or
         # UNC path, redundant separators, "." and ".." components.
         arcname = os.path.splitdrive(arcname)[1]
+        invalid_path_parts = ('', os.path.curdir, os.path.pardir)
         arcname = os.path.sep.join(x for x in arcname.split(os.path.sep)
-                    if x not in ('', os.path.curdir, os.path.pardir))
+                                   if x not in invalid_path_parts)
         if os.path.sep == '\\':
             # filter illegal characters on Windows
-            illegal = ':<>|"?*'
-            table = str.maketrans(illegal, '_' * len(illegal))
-            arcname = arcname.translate(table)
-            # remove trailing dots
-            arcname = (x.rstrip('.') for x in arcname.split(os.path.sep))
-            arcname = os.path.sep.join(x for x in arcname if x)
+            arcname = self._sanitize_windows_name(arcname, os.path.sep)
 
         targetpath = os.path.join(targetpath, arcname)
         targetpath = os.path.normpath(targetpath)