]> granicus.if.org Git - python/commitdiff
Make db modules' error classes inherit IOError.
authorGeorg Brandl <georg@python.org>
Wed, 28 May 2008 08:43:17 +0000 (08:43 +0000)
committerGeorg Brandl <georg@python.org>
Wed, 28 May 2008 08:43:17 +0000 (08:43 +0000)
Stop dbm from importing every dbm module when imported.

Lib/dbm/__init__.py
Lib/dbm/bsd.py
Lib/test/test_dbm.py
Modules/_dbmmodule.c
Modules/_gdbmmodule.c

index 9fdd4145cc19aba4fe6467109b14f3be82cccb0a..2082e07335705f29fc7d4ea85af0f9dfa03e84e3 100644 (file)
@@ -48,27 +48,26 @@ class error(Exception):
     pass
 
 _names = ['dbm.bsd', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb']
-_errors = [error]
 _defaultmod = None
 _modules = {}
 
-for _name in _names:
-    try:
-        _mod = __import__(_name, fromlist=['open'])
-    except ImportError:
-        continue
-    if not _defaultmod:
-        _defaultmod = _mod
-    _modules[_name] = _mod
-    _errors.append(_mod.error)
-
-if not _defaultmod:
-    raise ImportError("no dbm clone found; tried %s" % _names)
-
-error = tuple(_errors)
+error = (error, IOError)
 
 
 def open(file, flag = 'r', mode = 0o666):
+    global _defaultmod
+    if _defaultmod is None:
+        for name in _names:
+            try:
+                mod = __import__(name, fromlist=['open'])
+            except ImportError:
+                continue
+            if not _defaultmod:
+                _defaultmod = mod
+            _modules[name] = mod
+        if not _defaultmod:
+            raise ImportError("no dbm clone found; tried %s" % _names)
+
     # guess the type of an existing database
     result = whichdb(file)
     if result is None:
@@ -81,19 +80,14 @@ def open(file, flag = 'r', mode = 0o666):
     elif result == "":
         # db type cannot be determined
         raise error("db type could not be determined")
+    elif result not in _modules:
+        raise error("db type is {0}, but the module is not "
+                    "available".format(result))
     else:
         mod = _modules[result]
     return mod.open(file, flag, mode)
 
 
-try:
-    from dbm import ndbm
-    _dbmerror = ndbm.error
-except ImportError:
-    ndbm = None
-    # just some sort of valid exception which might be raised in the ndbm test
-    _dbmerror = IOError
-
 def whichdb(filename):
     """Guess which db package to use to open a db file.
 
@@ -129,7 +123,7 @@ def whichdb(filename):
                 d = ndbm.open(filename)
                 d.close()
                 return "dbm.ndbm"
-        except (IOError, _dbmerror):
+        except IOError:
             pass
 
     # Check for dumbdbm next -- this has a .dir and a .dat file
index 8353f5037661ab4377a7e9324ef9878016d0dc5a..2dccadb8b1e4666ed6fbd01e30e9a7a583386829 100644 (file)
@@ -4,7 +4,8 @@ import bsddb
 
 __all__ = ["error", "open"]
 
-error = bsddb.error
+class error(bsddb.error, IOError):
+    pass
 
 def open(file, flag = 'r', mode=0o666):
     return bsddb.hashopen(file, flag, mode)
index aab1388d79111e8d090e331ca444121872037a09..41c37cbea515bee9407f58c6a45bc1320e7f8aae 100644 (file)
@@ -14,11 +14,13 @@ _fname = test.support.TESTFN
 # setting dbm to use each in turn, and yielding that module
 #
 def dbm_iterator():
-    old_default = dbm._defaultmod
-    for module in dbm._modules.values():
-        dbm._defaultmod = module
-        yield module
-    dbm._defaultmod = old_default
+    for name in dbm._names:
+        try:
+            mod = __import__(name, fromlist=['open'])
+        except ImportError:
+            continue
+        dbm._modules[name] = mod
+        yield mod
 
 #
 # Clean up all scratch databases we might have created during testing
@@ -40,8 +42,20 @@ class AnyDBMTestCase(unittest.TestCase):
              'g': b'intended',
              }
 
-    def __init__(self, *args):
-        unittest.TestCase.__init__(self, *args)
+    def init_db(self):
+        f = dbm.open(_fname, 'n')
+        for k in self._dict:
+            f[k.encode("ascii")] = self._dict[k]
+        f.close()
+
+    def keys_helper(self, f):
+        keys = sorted(k.decode("ascii") for k in f.keys())
+        dkeys = sorted(self._dict.keys())
+        self.assertEqual(keys, dkeys)
+        return keys
+
+    def test_error(self):
+        self.assert_(issubclass(self.module.error, IOError))
 
     def test_anydbm_creation(self):
         f = dbm.open(_fname, 'c')
@@ -83,22 +97,11 @@ class AnyDBMTestCase(unittest.TestCase):
         for key in self._dict:
             self.assertEqual(self._dict[key], f[key.encode("ascii")])
 
-    def init_db(self):
-        f = dbm.open(_fname, 'n')
-        for k in self._dict:
-            f[k.encode("ascii")] = self._dict[k]
-        f.close()
-
-    def keys_helper(self, f):
-        keys = sorted(k.decode("ascii") for k in f.keys())
-        dkeys = sorted(self._dict.keys())
-        self.assertEqual(keys, dkeys)
-        return keys
-
     def tearDown(self):
         delete_files()
 
     def setUp(self):
+        dbm._defaultmod = self.module
         delete_files()
 
 
@@ -137,11 +140,11 @@ class WhichDBTestCase(unittest.TestCase):
 
 
 def test_main():
-    try:
-        for module in dbm_iterator():
-            test.support.run_unittest(AnyDBMTestCase, WhichDBTestCase)
-    finally:
-        delete_files()
+    classes = [WhichDBTestCase]
+    for mod in dbm_iterator():
+        classes.append(type("TestCase-" + mod.__name__, (AnyDBMTestCase,),
+                            {'module': mod}))
+    test.support.run_unittest(*classes)
 
 if __name__ == "__main__":
     test_main()
index ddfd4cd759aae8bb57f8a11562f3bba12488be4f..7e80381db77e4f37dbeb256ab2674a975a76223f 100644 (file)
@@ -401,7 +401,8 @@ init_dbm(void) {
                return;
        d = PyModule_GetDict(m);
        if (DbmError == NULL)
-               DbmError = PyErr_NewException("_dbm.error", NULL, NULL);
+               DbmError = PyErr_NewException("_dbm.error",
+                                             PyExc_IOError, NULL);
        s = PyUnicode_FromString(which_dbm);
        if (s != NULL) {
                PyDict_SetItemString(d, "library", s);
index 6c7581969ae2c536955d042e553e0efc155b5145..abc88370911e7ece85eced0c2294f425b3d6a4e7 100644 (file)
@@ -523,7 +523,7 @@ init_gdbm(void) {
     if (m == NULL)
        return;
     d = PyModule_GetDict(m);
-    DbmError = PyErr_NewException("_gdbm.error", NULL, NULL);
+    DbmError = PyErr_NewException("_gdbm.error", PyExc_IOError, NULL);
     if (DbmError != NULL) {
         PyDict_SetItemString(d, "error", DbmError);
         s = PyUnicode_FromString(dbmmodule_open_flags);