]> granicus.if.org Git - python/commitdiff
closes bpo-37347: Fix refcount problem in sqlite3. (GH-14268)
authorgescheit <gescheit@yandex-team.ru>
Sat, 13 Jul 2019 03:15:49 +0000 (06:15 +0300)
committerBenjamin Peterson <benjamin@python.org>
Sat, 13 Jul 2019 03:15:48 +0000 (20:15 -0700)
Lib/sqlite3/test/regression.py
Misc/ACKS
Misc/NEWS.d/next/Library/2019-06-20-14-23-48.bpo-37347.Gf9yYI.rst [new file with mode: 0644]
Modules/_sqlite/connection.c
Modules/_sqlite/connection.h
setup.py

index ee326168bc125a6c4761be7fe91b6d67db782e98..c714116ac49208bd1cd31fbe985b003bb81c82f8 100644 (file)
@@ -25,6 +25,7 @@ import datetime
 import unittest
 import sqlite3 as sqlite
 import weakref
+import functools
 from test import support
 
 class RegressionTests(unittest.TestCase):
@@ -383,72 +384,26 @@ class RegressionTests(unittest.TestCase):
         with self.assertRaises(AttributeError):
             del self.con.isolation_level
 
+    def CheckBpo37347(self):
+        class Printer:
+            def log(self, *args):
+                return sqlite.SQLITE_OK
 
-class UnhashableFunc:
-    __hash__ = None
+        for method in [self.con.set_trace_callback,
+                       functools.partial(self.con.set_progress_handler, n=1),
+                       self.con.set_authorizer]:
+            printer_instance = Printer()
+            method(printer_instance.log)
+            method(printer_instance.log)
+            self.con.execute("select 1")  # trigger seg fault
+            method(None)
 
-    def __init__(self, return_value=None):
-        self.calls = 0
-        self.return_value = return_value
-
-    def __call__(self, *args, **kwargs):
-        self.calls += 1
-        return self.return_value
-
-
-class UnhashableCallbacksTestCase(unittest.TestCase):
-    """
-    https://bugs.python.org/issue34052
-
-    Registering unhashable callbacks raises TypeError, callbacks are not
-    registered in SQLite after such registration attempt.
-    """
-    def setUp(self):
-        self.con = sqlite.connect(':memory:')
-
-    def tearDown(self):
-        self.con.close()
-
-    def test_progress_handler(self):
-        f = UnhashableFunc(return_value=0)
-        with self.assertRaisesRegex(TypeError, 'unhashable type'):
-            self.con.set_progress_handler(f, 1)
-        self.con.execute('SELECT 1')
-        self.assertFalse(f.calls)
-
-    def test_func(self):
-        func_name = 'func_name'
-        f = UnhashableFunc()
-        with self.assertRaisesRegex(TypeError, 'unhashable type'):
-            self.con.create_function(func_name, 0, f)
-        msg = 'no such function: %s' % func_name
-        with self.assertRaisesRegex(sqlite.OperationalError, msg):
-            self.con.execute('SELECT %s()' % func_name)
-        self.assertFalse(f.calls)
-
-    def test_authorizer(self):
-        f = UnhashableFunc(return_value=sqlite.SQLITE_DENY)
-        with self.assertRaisesRegex(TypeError, 'unhashable type'):
-            self.con.set_authorizer(f)
-        self.con.execute('SELECT 1')
-        self.assertFalse(f.calls)
-
-    def test_aggr(self):
-        class UnhashableType(type):
-            __hash__ = None
-        aggr_name = 'aggr_name'
-        with self.assertRaisesRegex(TypeError, 'unhashable type'):
-            self.con.create_aggregate(aggr_name, 0, UnhashableType('Aggr', (), {}))
-        msg = 'no such function: %s' % aggr_name
-        with self.assertRaisesRegex(sqlite.OperationalError, msg):
-            self.con.execute('SELECT %s()' % aggr_name)
 
 
 def suite():
     regression_suite = unittest.makeSuite(RegressionTests, "Check")
     return unittest.TestSuite((
         regression_suite,
-        unittest.makeSuite(UnhashableCallbacksTestCase),
     ))
 
 def test():
index d916c45a8e441b46ced3acd188eac7df95b32c2a..c0119992cffe30445202c2c1c08b129e4c2ef02d 100644 (file)
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -1870,3 +1870,4 @@ Diego Rojas
 Edison Abahurire
 Geoff Shannon
 Batuhan Taskaya
+Aleksandr Balezin
diff --git a/Misc/NEWS.d/next/Library/2019-06-20-14-23-48.bpo-37347.Gf9yYI.rst b/Misc/NEWS.d/next/Library/2019-06-20-14-23-48.bpo-37347.Gf9yYI.rst
new file mode 100644 (file)
index 0000000..1e61f5e
--- /dev/null
@@ -0,0 +1,6 @@
+:meth:`sqlite3.Connection.create_aggregate`,\r
+:meth:`sqlite3.Connection.create_function`,\r
+:meth:`sqlite3.Connection.set_authorizer`,\r
+:meth:`sqlite3.Connection.set_progress_handler` \r
+:meth:`sqlite3.Connection.set_trace_callback` \r
+methods lead to segfaults if some of these methods are called twice with an equal object but not the same. Now callbacks are stored more carefully. Patch by Aleksandr Balezin.
\ No newline at end of file
index 47b8b62cea26cf9dbe5d7c3397d2cd8181f470f4..ae03aed8f0b885b0151645d2a34c2221e9ea3c3b 100644 (file)
@@ -186,10 +186,9 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
     }
     self->check_same_thread = check_same_thread;
 
-    Py_XSETREF(self->function_pinboard, PyDict_New());
-    if (!self->function_pinboard) {
-        return -1;
-    }
+    self->function_pinboard_trace_callback = NULL;
+    self->function_pinboard_progress_handler = NULL;
+    self->function_pinboard_authorizer_cb = NULL;
 
     Py_XSETREF(self->collations, PyDict_New());
     if (!self->collations) {
@@ -249,19 +248,18 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
 
     /* Clean up if user has not called .close() explicitly. */
     if (self->db) {
-        Py_BEGIN_ALLOW_THREADS
         SQLITE3_CLOSE(self->db);
-        Py_END_ALLOW_THREADS
     }
 
     Py_XDECREF(self->isolation_level);
-    Py_XDECREF(self->function_pinboard);
+    Py_XDECREF(self->function_pinboard_trace_callback);
+    Py_XDECREF(self->function_pinboard_progress_handler);
+    Py_XDECREF(self->function_pinboard_authorizer_cb);
     Py_XDECREF(self->row_factory);
     Py_XDECREF(self->text_factory);
     Py_XDECREF(self->collations);
     Py_XDECREF(self->statements);
     Py_XDECREF(self->cursors);
-
     Py_TYPE(self)->tp_free((PyObject*)self);
 }
 
@@ -342,9 +340,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)
     pysqlite_do_all_statements(self, ACTION_FINALIZE, 1);
 
     if (self->db) {
-        Py_BEGIN_ALLOW_THREADS
         rc = SQLITE3_CLOSE(self->db);
-        Py_END_ALLOW_THREADS
 
         if (rc != SQLITE_OK) {
             _pysqlite_seterror(self->db, NULL);
@@ -808,6 +804,11 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
     Py_SETREF(self->cursors, new_list);
 }
 
+static void _destructor(void* args)
+{
+    Py_DECREF((PyObject*)args);
+}
+
 PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
 {
     static char *kwlist[] = {"name", "narg", "func", "deterministic", NULL};
@@ -843,17 +844,16 @@ PyObject* pysqlite_connection_create_function(pysqlite_Connection* self, PyObjec
         flags |= SQLITE_DETERMINISTIC;
 #endif
     }
-    if (PyDict_SetItem(self->function_pinboard, func, Py_None) == -1) {
-        return NULL;
-    }
-    rc = sqlite3_create_function(self->db,
-                                 name,
-                                 narg,
-                                 flags,
-                                 (void*)func,
-                                 _pysqlite_func_callback,
-                                 NULL,
-                                 NULL);
+    Py_INCREF(func);
+    rc = sqlite3_create_function_v2(self->db,
+                                    name,
+                                    narg,
+                                    flags,
+                                    (void*)func,
+                                    _pysqlite_func_callback,
+                                    NULL,
+                                    NULL,
+                                    &_destructor);  // will decref func
 
     if (rc != SQLITE_OK) {
         /* Workaround for SQLite bug: no error code or string is available here */
@@ -880,11 +880,16 @@ PyObject* pysqlite_connection_create_aggregate(pysqlite_Connection* self, PyObje
                                       kwlist, &name, &n_arg, &aggregate_class)) {
         return NULL;
     }
-
-    if (PyDict_SetItem(self->function_pinboard, aggregate_class, Py_None) == -1) {
-        return NULL;
-    }
-    rc = sqlite3_create_function(self->db, name, n_arg, SQLITE_UTF8, (void*)aggregate_class, 0, &_pysqlite_step_callback, &_pysqlite_final_callback);
+    Py_INCREF(aggregate_class);
+    rc = sqlite3_create_function_v2(self->db,
+                                    name,
+                                    n_arg,
+                                    SQLITE_UTF8,
+                                    (void*)aggregate_class,
+                                    0,
+                                    &_pysqlite_step_callback,
+                                    &_pysqlite_final_callback,
+                                    &_destructor); // will decref func
     if (rc != SQLITE_OK) {
         /* Workaround for SQLite bug: no error code or string is available here */
         PyErr_SetString(pysqlite_OperationalError, "Error creating aggregate");
@@ -1003,13 +1008,14 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P
         return NULL;
     }
 
-    if (PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None) == -1) {
-        return NULL;
-    }
     rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);
     if (rc != SQLITE_OK) {
         PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback");
+        Py_XSETREF(self->function_pinboard_authorizer_cb, NULL);
         return NULL;
+    } else {
+        Py_INCREF(authorizer_cb);
+        Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb);
     }
     Py_RETURN_NONE;
 }
@@ -1033,12 +1039,12 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
     if (progress_handler == Py_None) {
         /* None clears the progress handler previously set */
         sqlite3_progress_handler(self->db, 0, 0, (void*)0);
+        Py_XSETREF(self->function_pinboard_progress_handler, NULL);
     } else {
-        if (PyDict_SetItem(self->function_pinboard, progress_handler, Py_None) == -1)
-            return NULL;
         sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler);
+        Py_INCREF(progress_handler);
+        Py_XSETREF(self->function_pinboard_progress_handler, progress_handler);
     }
-
     Py_RETURN_NONE;
 }
 
@@ -1060,10 +1066,11 @@ static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* sel
     if (trace_callback == Py_None) {
         /* None clears the trace callback previously set */
         sqlite3_trace(self->db, 0, (void*)0);
+        Py_XSETREF(self->function_pinboard_trace_callback, NULL);
     } else {
-        if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
-            return NULL;
         sqlite3_trace(self->db, _trace_callback, trace_callback);
+        Py_INCREF(trace_callback);
+        Py_XSETREF(self->function_pinboard_trace_callback, trace_callback);
     }
 
     Py_RETURN_NONE;
index 4e9d94c5f3089135affe3173d74750a293b7552d..206085e00a00c707fee85abadf19c75da8819fd9 100644 (file)
@@ -85,11 +85,10 @@ typedef struct
      */
     PyObject* text_factory;
 
-    /* remember references to functions/classes used in
-     * create_function/create/aggregate, use these as dictionary keys, so we
-     * can keep the total system refcount constant by clearing that dictionary
-     * in connection_dealloc */
-    PyObject* function_pinboard;
+    /* remember references to object used in trace_callback/progress_handler/authorizer_cb */
+    PyObject* function_pinboard_trace_callback;
+    PyObject* function_pinboard_progress_handler;
+    PyObject* function_pinboard_authorizer_cb;
 
     /* a dictionary of registered collation name => collation callable mappings */
     PyObject* collations;
index 3ec89cedfd57599a80a9e57466b3a9fe068d943e..7cd77257ae35df73dc757a7bfc9d963466239ae6 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -1357,7 +1357,7 @@ class PyBuildExt(build_ext):
                              ]
         if CROSS_COMPILING:
             sqlite_inc_paths = []
-        MIN_SQLITE_VERSION_NUMBER = (3, 3, 9)
+        MIN_SQLITE_VERSION_NUMBER = (3, 7, 2)
         MIN_SQLITE_VERSION = ".".join([str(x)
                                     for x in MIN_SQLITE_VERSION_NUMBER])