From: Gerhard Häring Date: Fri, 29 Feb 2008 22:08:41 +0000 (+0000) Subject: Updated to pysqlite 2.4.1. Documentation additions will come later. X-Git-Tag: v2.6a1~8 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=1cc60ed214d83b1901a9e68782559c18f705ff07;p=python Updated to pysqlite 2.4.1. Documentation additions will come later. --- diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index b08da9c5fa..b27486d5b8 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/dbapi.py: tests for DB-API compliance # -# Copyright (C) 2004-2005 Gerhard Häring +# Copyright (C) 2004-2007 Gerhard Häring # # This file is part of pysqlite. # @@ -22,6 +22,7 @@ # 3. This notice may not be removed or altered from any source distribution. import unittest +import sys import threading import sqlite3 as sqlite @@ -223,12 +224,45 @@ class CursorTests(unittest.TestCase): except sqlite.ProgrammingError: pass + def CheckExecuteParamList(self): + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=?", ["foo"]) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + + def CheckExecuteParamSequence(self): + class L(object): + def __len__(self): + return 1 + def __getitem__(self, x): + assert x == 0 + return "foo" + + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=?", L()) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMapping(self): self.cu.execute("insert into test(name) values ('foo')") self.cu.execute("select name from test where name=:name", {"name": "foo"}) row = self.cu.fetchone() self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMapping_Mapping(self): + # Test only works with Python 2.5 or later + if sys.version_info < (2, 5, 0): + return + + class D(dict): + def __missing__(self, key): + return "foo" + + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=:name", D()) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMappingTooLittleArgs(self): self.cu.execute("insert into test(name) values ('foo')") try: @@ -378,6 +412,12 @@ class CursorTests(unittest.TestCase): res = self.cu.fetchmany(100) self.failUnlessEqual(res, []) + def CheckFetchmanyKwArg(self): + """Checks if fetchmany works with keyword arguments""" + self.cu.execute("select name from test") + res = self.cu.fetchmany(size=100) + self.failUnlessEqual(len(res), 1) + def CheckFetchall(self): self.cu.execute("select name from test") res = self.cu.fetchall() diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index cb0a6216fd..547dc65934 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks # -# Copyright (C) 2006 Gerhard Häring +# Copyright (C) 2006-2007 Gerhard Häring # # This file is part of pysqlite. # @@ -21,7 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import unittest +import os, unittest import sqlite3 as sqlite class CollationTests(unittest.TestCase): @@ -105,9 +105,80 @@ class CollationTests(unittest.TestCase): if not e.args[0].startswith("no such collation sequence"): self.fail("wrong OperationalError raised") +class ProgressTests(unittest.TestCase): + def CheckProgressHandlerUsed(self): + """ + Test that the progress handler is invoked once it is set. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 0 + con.set_progress_handler(progress, 1) + con.execute(""" + create table foo(a, b) + """) + self.failUnless(progress_calls) + + + def CheckOpcodeCount(self): + """ + Test that the opcode argument is respected. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 0 + con.set_progress_handler(progress, 1) + curs = con.cursor() + curs.execute(""" + create table foo (a, b) + """) + first_count = len(progress_calls) + progress_calls = [] + con.set_progress_handler(progress, 2) + curs.execute(""" + create table bar (a, b) + """) + second_count = len(progress_calls) + self.failUnless(first_count > second_count) + + def CheckCancelOperation(self): + """ + Test that returning a non-zero value stops the operation in progress. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 1 + con.set_progress_handler(progress, 1) + curs = con.cursor() + self.assertRaises( + sqlite.OperationalError, + curs.execute, + "create table bar (a, b)") + + def CheckClearHandler(self): + """ + Test that setting the progress handler to None clears the previously set handler. + """ + con = sqlite.connect(":memory:") + action = 0 + def progress(): + action = 1 + return 0 + con.set_progress_handler(progress, 1) + con.set_progress_handler(None, 1) + con.execute("select 1 union select 2 union select 3").fetchall() + self.failUnlessEqual(action, 0, "progress handler was not cleared") + def suite(): collation_suite = unittest.makeSuite(CollationTests, "Check") - return unittest.TestSuite((collation_suite,)) + progress_suite = unittest.makeSuite(ProgressTests, "Check") + return unittest.TestSuite((collation_suite, progress_suite)) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/py25tests.py b/Lib/sqlite3/test/py25tests.py new file mode 100644 index 0000000000..bce26b907b --- /dev/null +++ b/Lib/sqlite3/test/py25tests.py @@ -0,0 +1,80 @@ +#-*- coding: ISO-8859-1 -*- +# pysqlite2/test/regression.py: pysqlite regression tests +# +# Copyright (C) 2007 Gerhard Häring +# +# This file is part of pysqlite. +# +# This software is provided 'as-is', without any express or implied +# warranty. In no event will the authors be held liable for any damages +# arising from the use of this software. +# +# Permission is granted to anyone to use this software for any purpose, +# including commercial applications, and to alter it and redistribute it +# freely, subject to the following restrictions: +# +# 1. The origin of this software must not be misrepresented; you must not +# claim that you wrote the original software. If you use this software +# in a product, an acknowledgment in the product documentation would be +# appreciated but is not required. +# 2. Altered source versions must be plainly marked as such, and must not be +# misrepresented as being the original software. +# 3. This notice may not be removed or altered from any source distribution. + +from __future__ import with_statement +import unittest +import sqlite3 as sqlite + +did_rollback = False + +class MyConnection(sqlite.Connection): + def rollback(self): + global did_rollback + did_rollback = True + sqlite.Connection.rollback(self) + +class ContextTests(unittest.TestCase): + def setUp(self): + global did_rollback + self.con = sqlite.connect(":memory:", factory=MyConnection) + self.con.execute("create table test(c unique)") + did_rollback = False + + def tearDown(self): + self.con.close() + + def CheckContextManager(self): + """Can the connection be used as a context manager at all?""" + with self.con: + pass + + def CheckContextManagerCommit(self): + """Is a commit called in the context manager?""" + with self.con: + self.con.execute("insert into test(c) values ('foo')") + self.con.rollback() + count = self.con.execute("select count(*) from test").fetchone()[0] + self.failUnlessEqual(count, 1) + + def CheckContextManagerRollback(self): + """Is a rollback called in the context manager?""" + global did_rollback + self.failUnlessEqual(did_rollback, False) + try: + with self.con: + self.con.execute("insert into test(c) values (4)") + self.con.execute("insert into test(c) values (4)") + except sqlite.IntegrityError: + pass + self.failUnlessEqual(did_rollback, True) + +def suite(): + ctx_suite = unittest.makeSuite(ContextTests, "Check") + return unittest.TestSuite((ctx_suite,)) + +def test(): + runner = unittest.TextTestRunner() + runner.run(suite()) + +if __name__ == "__main__": + test() diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index addedb1771..45eae90554 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/regression.py: pysqlite regression tests # -# Copyright (C) 2006 Gerhard Häring +# Copyright (C) 2006-2007 Gerhard Häring # # This file is part of pysqlite. # @@ -21,6 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import datetime import unittest import sqlite3 as sqlite @@ -79,6 +80,79 @@ class RegressionTests(unittest.TestCase): cur.fetchone() cur.fetchone() + def CheckStatementFinalizationOnCloseDb(self): + # pysqlite versions <= 2.3.3 only finalized statements in the statement + # cache when closing the database. statements that were still + # referenced in cursors weren't closed an could provoke " + # "OperationalError: Unable to close due to unfinalised statements". + con = sqlite.connect(":memory:") + cursors = [] + # default statement cache size is 100 + for i in range(105): + cur = con.cursor() + cursors.append(cur) + cur.execute("select 1 x union select " + str(i)) + con.close() + + def CheckOnConflictRollback(self): + if sqlite.sqlite_version_info < (3, 2, 2): + return + con = sqlite.connect(":memory:") + con.execute("create table foo(x, unique(x) on conflict rollback)") + con.execute("insert into foo(x) values (1)") + try: + con.execute("insert into foo(x) values (1)") + except sqlite.DatabaseError: + pass + con.execute("insert into foo(x) values (2)") + try: + con.commit() + except sqlite.OperationalError: + self.fail("pysqlite knew nothing about the implicit ROLLBACK") + + def CheckWorkaroundForBuggySqliteTransferBindings(self): + """ + pysqlite would crash with older SQLite versions unless + a workaround is implemented. + """ + self.con.execute("create table if not exists foo(bar)") + self.con.execute("create table if not exists foo(bar)") + + def CheckEmptyStatement(self): + """ + pysqlite used to segfault with SQLite versions 3.5.x. These return NULL + for "no-operation" statements + """ + self.con.execute("") + + def CheckUnicodeConnect(self): + """ + With pysqlite 2.4.0 you needed to use a string or a APSW connection + object for opening database connections. + + Formerly, both bytestrings and unicode strings used to work. + + Let's make sure unicode strings work in the future. + """ + con = sqlite.connect(u":memory:") + con.close() + + def CheckTypeMapUsage(self): + """ + pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling + a statement. This test exhibits the problem. + """ + SELECT = "select * from foo" + con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) + con.execute("create table foo(bar timestamp)") + con.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) + con.execute(SELECT) + con.execute("drop table foo") + con.execute("create table foo(bar integer)") + con.execute("insert into foo(bar) values (5)") + con.execute(SELECT) + + def suite(): regression_suite = unittest.makeSuite(RegressionTests, "Check") return unittest.TestSuite((regression_suite,)) diff --git a/Lib/sqlite3/test/transactions.py b/Lib/sqlite3/test/transactions.py index 1f0b19aa9b..14cae25001 100644 --- a/Lib/sqlite3/test/transactions.py +++ b/Lib/sqlite3/test/transactions.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/transactions.py: tests transactions # -# Copyright (C) 2005 Gerhard Häring +# Copyright (C) 2005-2007 Gerhard Häring # # This file is part of pysqlite. # @@ -21,6 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import sys import os, unittest import sqlite3 as sqlite @@ -119,6 +120,23 @@ class TransactionTests(unittest.TestCase): except: self.fail("should have raised an OperationalError") + def CheckLocking(self): + """ + This tests the improved concurrency with pysqlite 2.3.4. You needed + to roll back con2 before you could commit con1. + """ + self.cur1.execute("create table test(i)") + self.cur1.execute("insert into test(i) values (5)") + try: + self.cur2.execute("insert into test(i) values (5)") + self.fail("should have raised an OperationalError") + except sqlite.OperationalError: + pass + except: + self.fail("should have raised an OperationalError") + # NO self.con2.rollback() HERE!!! + self.con1.commit() + class SpecialCommandTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 3cc9affc5d..197040172c 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/types.py: tests for type conversion and detection # -# Copyright (C) 2005 Gerhard Häring +# Copyright (C) 2005-2007 Gerhard Häring # # This file is part of pysqlite. # @@ -21,7 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import bz2, datetime +import zlib, datetime import unittest import sqlite3 as sqlite @@ -287,7 +287,7 @@ class ObjectAdaptationTests(unittest.TestCase): class BinaryConverterTests(unittest.TestCase): def convert(s): - return bz2.decompress(s) + return zlib.decompress(s) convert = staticmethod(convert) def setUp(self): @@ -299,7 +299,7 @@ class BinaryConverterTests(unittest.TestCase): def CheckBinaryInputForConverter(self): testdata = "abcdefg" * 10 - result = self.con.execute('select ? as "x [bin]"', (buffer(bz2.compress(testdata)),)).fetchone()[0] + result = self.con.execute('select ? as "x [bin]"', (buffer(zlib.compress(testdata)),)).fetchone()[0] self.failUnlessEqual(testdata, result) class DateTimeTests(unittest.TestCase): @@ -331,7 +331,8 @@ class DateTimeTests(unittest.TestCase): if sqlite.sqlite_version_info < (3, 1): return - now = datetime.datetime.utcnow() + # SQLite's current_timestamp uses UTC time, while datetime.datetime.now() uses local time. + now = datetime.datetime.now() self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") ts = self.cur.fetchone()[0] diff --git a/Lib/test/test_sqlite.py b/Lib/test/test_sqlite.py index c1523e11ba..3566f316b1 100644 --- a/Lib/test/test_sqlite.py +++ b/Lib/test/test_sqlite.py @@ -4,13 +4,13 @@ try: import _sqlite3 except ImportError: raise TestSkipped('no sqlite available') -from sqlite3.test import (dbapi, types, userfunctions, +from sqlite3.test import (dbapi, types, userfunctions, py25tests, factory, transactions, hooks, regression) def test_main(): run_unittest(dbapi.suite(), types.suite(), userfunctions.suite(), - factory.suite(), transactions.suite(), hooks.suite(), - regression.suite()) + py25tests.suite(), factory.suite(), transactions.suite(), + hooks.suite(), regression.suite()) if __name__ == "__main__": test_main() diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index f65748aa8c..1ce275c2e2 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1,6 +1,6 @@ /* connection.c - the connection type * - * Copyright (C) 2004-2006 Gerhard Häring + * Copyright (C) 2004-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -32,6 +32,9 @@ #include "pythread.h" +#define ACTION_FINALIZE 1 +#define ACTION_RESET 2 + static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level); @@ -51,7 +54,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject { static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL}; - char* database; + PyObject* database; int detect_types = 0; PyObject* isolation_level = NULL; PyObject* factory = NULL; @@ -59,11 +62,15 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject int cached_statements = 100; double timeout = 5.0; int rc; + PyObject* class_attr = NULL; + PyObject* class_attr_str = NULL; + int is_apsw_connection = 0; + PyObject* database_utf8; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist, &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements)) { - return -1; + return -1; } self->begin_statement = NULL; @@ -77,13 +84,53 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject Py_INCREF(&PyUnicode_Type); self->text_factory = (PyObject*)&PyUnicode_Type; - Py_BEGIN_ALLOW_THREADS - rc = sqlite3_open(database, &self->db); - Py_END_ALLOW_THREADS + if (PyString_Check(database) || PyUnicode_Check(database)) { + if (PyString_Check(database)) { + database_utf8 = database; + Py_INCREF(database_utf8); + } else { + database_utf8 = PyUnicode_AsUTF8String(database); + if (!database_utf8) { + return -1; + } + } - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return -1; + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_open(PyString_AsString(database_utf8), &self->db); + Py_END_ALLOW_THREADS + + Py_DECREF(database_utf8); + + if (rc != SQLITE_OK) { + _pysqlite_seterror(self->db, NULL); + return -1; + } + } else { + /* Create a pysqlite connection from a APSW connection */ + class_attr = PyObject_GetAttrString(database, "__class__"); + if (class_attr) { + class_attr_str = PyObject_Str(class_attr); + if (class_attr_str) { + if (strcmp(PyString_AsString(class_attr_str), "") == 0) { + /* In the APSW Connection object, the first entry after + * PyObject_HEAD is the sqlite3* we want to get hold of. + * Luckily, this is the same layout as we have in our + * pysqlite_Connection */ + self->db = ((pysqlite_Connection*)database)->db; + + Py_INCREF(database); + self->apsw_connection = database; + is_apsw_connection = 1; + } + } + } + Py_XDECREF(class_attr_str); + Py_XDECREF(class_attr); + + if (!is_apsw_connection) { + PyErr_SetString(PyExc_ValueError, "database parameter must be string or APSW Connection object"); + return -1; + } } if (!isolation_level) { @@ -169,7 +216,8 @@ void pysqlite_flush_statement_cache(pysqlite_Connection* self) self->statement_cache->decref_factory = 0; } -void pysqlite_reset_all_statements(pysqlite_Connection* self) +/* action in (ACTION_RESET, ACTION_FINALIZE) */ +void pysqlite_do_all_statements(pysqlite_Connection* self, int action) { int i; PyObject* weakref; @@ -179,13 +227,19 @@ void pysqlite_reset_all_statements(pysqlite_Connection* self) weakref = PyList_GetItem(self->statements, i); statement = PyWeakref_GetObject(weakref); if (statement != Py_None) { - (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + if (action == ACTION_RESET) { + (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + } else { + (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); + } } } } void pysqlite_connection_dealloc(pysqlite_Connection* self) { + PyObject* ret = NULL; + Py_XDECREF(self->statement_cache); /* Clean up if user has not called .close() explicitly. */ @@ -193,6 +247,10 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self) Py_BEGIN_ALLOW_THREADS sqlite3_close(self->db); Py_END_ALLOW_THREADS + } else if (self->apsw_connection) { + ret = PyObject_CallMethod(self->apsw_connection, "close", ""); + Py_XDECREF(ret); + Py_XDECREF(self->apsw_connection); } if (self->begin_statement) { @@ -205,7 +263,7 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self) Py_XDECREF(self->collations); Py_XDECREF(self->statements); - Py_TYPE(self)->tp_free((PyObject*)self); + self->ob_type->tp_free((PyObject*)self); } PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -241,24 +299,33 @@ PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args) { + PyObject* ret; int rc; if (!pysqlite_check_thread(self)) { return NULL; } - pysqlite_flush_statement_cache(self); + pysqlite_do_all_statements(self, ACTION_FINALIZE); if (self->db) { - Py_BEGIN_ALLOW_THREADS - rc = sqlite3_close(self->db); - Py_END_ALLOW_THREADS - - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return NULL; - } else { + if (self->apsw_connection) { + ret = PyObject_CallMethod(self->apsw_connection, "close", ""); + Py_XDECREF(ret); + Py_XDECREF(self->apsw_connection); + self->apsw_connection = NULL; self->db = NULL; + } else { + Py_BEGIN_ALLOW_THREADS + rc = sqlite3_close(self->db); + Py_END_ALLOW_THREADS + + if (rc != SQLITE_OK) { + _pysqlite_seterror(self->db, NULL); + return NULL; + } else { + self->db = NULL; + } } } @@ -292,7 +359,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); goto error; } @@ -300,7 +367,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) if (rc == SQLITE_DONE) { self->inTransaction = 1; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS @@ -308,7 +375,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } error: @@ -335,7 +402,7 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -343,14 +410,14 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -375,13 +442,13 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args } if (self->inTransaction) { - pysqlite_reset_all_statements(self); + pysqlite_do_all_statements(self, ACTION_RESET); Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -389,14 +456,14 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -762,6 +829,33 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co return rc; } +static int _progress_handler(void* user_arg) +{ + int rc; + PyObject *ret; + PyGILState_STATE gilstate; + + gilstate = PyGILState_Ensure(); + ret = PyObject_CallFunction((PyObject*)user_arg, ""); + + if (!ret) { + if (_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + + /* abort query if error occured */ + rc = 1; + } else { + rc = (int)PyObject_IsTrue(ret); + } + + Py_DECREF(ret); + PyGILState_Release(gilstate); + return rc; +} + PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { PyObject* authorizer_cb; @@ -787,6 +881,30 @@ PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject } } +PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ + PyObject* progress_handler; + int n; + + static char *kwlist[] = { "progress_handler", "n", NULL }; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler", + kwlist, &progress_handler, &n)) { + return NULL; + } + + if (progress_handler == Py_None) { + /* None clears the progress handler previously set */ + sqlite3_progress_handler(self->db, 0, 0, (void*)0); + } else { + sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); + PyDict_SetItem(self->function_pinboard, progress_handler, Py_None); + } + + Py_INCREF(Py_None); + return Py_None; +} + int pysqlite_check_thread(pysqlite_Connection* self) { if (self->check_same_thread) { @@ -892,7 +1010,8 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py } else if (rc == PYSQLITE_SQL_WRONG_TYPE) { PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode."); } else { - _pysqlite_seterror(self->db); + (void)pysqlite_statement_reset(statement); + _pysqlite_seterror(self->db, NULL); } Py_DECREF(statement); @@ -1134,7 +1253,7 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) (callable != Py_None) ? pysqlite_collation_callback : NULL); if (rc != SQLITE_OK) { PyDict_DelItem(self->collations, uppercase_name); - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto finally; } @@ -1151,6 +1270,44 @@ finally: return retval; } +/* Called when the connection is used as a context manager. Returns itself as a + * convenience to the caller. */ +static PyObject * +pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args) +{ + Py_INCREF(self); + return (PyObject*)self; +} + +/** Called when the connection is used as a context manager. If there was any + * exception, a rollback takes place; otherwise we commit. */ +static PyObject * +pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args) +{ + PyObject* exc_type, *exc_value, *exc_tb; + char* method_name; + PyObject* result; + + if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb)) { + return NULL; + } + + if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) { + method_name = "commit"; + } else { + method_name = "rollback"; + } + + result = PyObject_CallMethod((PyObject*)self, method_name, ""); + if (!result) { + return NULL; + } + Py_DECREF(result); + + Py_INCREF(Py_False); + return Py_False; +} + static char connection_doc[] = PyDoc_STR("SQLite database connection object."); @@ -1175,6 +1332,8 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a new aggregate. Non-standard.")}, {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Sets authorizer callback. Non-standard.")}, + {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("Sets progress handler callback. Non-standard.")}, {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS, PyDoc_STR("Executes a SQL statement. Non-standard.")}, {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS, @@ -1185,6 +1344,10 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a collation function. Non-standard.")}, {"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS, PyDoc_STR("Abort any pending database operation. Non-standard.")}, + {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS, + PyDoc_STR("For context manager. Non-standard.")}, + {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS, + PyDoc_STR("For context manager. Non-standard.")}, {NULL, NULL} }; diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 21fcd2a7e8..3b1c632db8 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -1,6 +1,6 @@ /* connection.h - definitions for the connection type * - * Copyright (C) 2004-2006 Gerhard Häring + * Copyright (C) 2004-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -95,6 +95,11 @@ typedef struct /* a dictionary of registered collation name => collation callable mappings */ PyObject* collations; + /* if our connection was created from a APSW connection, we keep a + * reference to the APSW connection around and get rid of it in our + * destructor */ + PyObject* apsw_connection; + /* Exception objects */ PyObject* Warning; PyObject* Error; diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 875d55b883..566e4ff0da 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -1,6 +1,6 @@ /* cursor.c - the cursor type * - * Copyright (C) 2004-2006 Gerhard Häring + * Copyright (C) 2004-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -80,7 +80,7 @@ int pysqlite_cursor_init(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs if (!PyArg_ParseTuple(args, "O!", &pysqlite_ConnectionType, &connection)) { - return -1; + return -1; } Py_INCREF(connection); @@ -435,7 +435,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* if (multiple) { /* executemany() */ if (!PyArg_ParseTuple(args, "OO", &operation, &second_argument)) { - return NULL; + return NULL; } if (!PyString_Check(operation) && !PyUnicode_Check(operation)) { @@ -457,7 +457,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* } else { /* execute() */ if (!PyArg_ParseTuple(args, "O|O", &operation, &second_argument)) { - return NULL; + return NULL; } if (!PyString_Check(operation) && !PyUnicode_Check(operation)) { @@ -506,16 +506,47 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation_cstr = PyString_AsString(operation_bytestr); } - /* reset description and rowcount */ + /* reset description */ Py_DECREF(self->description); Py_INCREF(Py_None); self->description = Py_None; - Py_DECREF(self->rowcount); - self->rowcount = PyInt_FromLong(-1L); - if (!self->rowcount) { + func_args = PyTuple_New(1); + if (!func_args) { goto error; } + Py_INCREF(operation); + if (PyTuple_SetItem(func_args, 0, operation) != 0) { + goto error; + } + + if (self->statement) { + (void)pysqlite_statement_reset(self->statement); + Py_DECREF(self->statement); + } + + self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args); + Py_DECREF(func_args); + + if (!self->statement) { + goto error; + } + + if (self->statement->in_use) { + Py_DECREF(self->statement); + self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType); + if (!self->statement) { + goto error; + } + rc = pysqlite_statement_create(self->statement, self->connection, operation); + if (rc != SQLITE_OK) { + self->statement = 0; + goto error; + } + } + + pysqlite_statement_reset(self->statement); + pysqlite_statement_mark_dirty(self->statement); statement_type = detect_statement_type(operation_cstr); if (self->connection->begin_statement) { @@ -553,43 +584,6 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* } } - func_args = PyTuple_New(1); - if (!func_args) { - goto error; - } - Py_INCREF(operation); - if (PyTuple_SetItem(func_args, 0, operation) != 0) { - goto error; - } - - if (self->statement) { - (void)pysqlite_statement_reset(self->statement); - Py_DECREF(self->statement); - } - - self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args); - Py_DECREF(func_args); - - if (!self->statement) { - goto error; - } - - if (self->statement->in_use) { - Py_DECREF(self->statement); - self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType); - if (!self->statement) { - goto error; - } - rc = pysqlite_statement_create(self->statement, self->connection, operation); - if (rc != SQLITE_OK) { - self->statement = 0; - goto error; - } - } - - pysqlite_statement_reset(self->statement); - pysqlite_statement_mark_dirty(self->statement); - while (1) { parameters = PyIter_Next(parameters_iter); if (!parameters) { @@ -603,11 +597,6 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* goto error; } - if (pysqlite_build_row_cast_map(self) != 0) { - PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map"); - goto error; - } - /* Keep trying the SQL statement until the schema stops changing. */ while (1) { /* Actually execute the SQL statement. */ @@ -626,7 +615,8 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* continue; } else { /* If the database gave us an error, promote it to Python. */ - _pysqlite_seterror(self->connection->db); + (void)pysqlite_statement_reset(self->statement); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } else { @@ -638,17 +628,23 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* PyErr_Clear(); } } - _pysqlite_seterror(self->connection->db); + (void)pysqlite_statement_reset(self->statement); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } - if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) { - Py_BEGIN_ALLOW_THREADS - numcols = sqlite3_column_count(self->statement->st); - Py_END_ALLOW_THREADS + if (pysqlite_build_row_cast_map(self) != 0) { + PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map"); + goto error; + } + if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) { if (self->description == Py_None) { + Py_BEGIN_ALLOW_THREADS + numcols = sqlite3_column_count(self->statement->st); + Py_END_ALLOW_THREADS + Py_DECREF(self->description); self->description = PyTuple_New(numcols); if (!self->description) { @@ -689,15 +685,11 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* case STATEMENT_DELETE: case STATEMENT_INSERT: case STATEMENT_REPLACE: - Py_BEGIN_ALLOW_THREADS rowcount += (long)sqlite3_changes(self->connection->db); - Py_END_ALLOW_THREADS - Py_DECREF(self->rowcount); - self->rowcount = PyInt_FromLong(rowcount); } Py_DECREF(self->lastrowid); - if (statement_type == STATEMENT_INSERT) { + if (!multiple && statement_type == STATEMENT_INSERT) { Py_BEGIN_ALLOW_THREADS lastrowid = sqlite3_last_insert_rowid(self->connection->db); Py_END_ALLOW_THREADS @@ -714,14 +706,27 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* } error: + /* just to be sure (implicit ROLLBACKs with ON CONFLICT ROLLBACK/OR + * ROLLBACK could have happened */ + #ifdef SQLITE_VERSION_NUMBER + #if SQLITE_VERSION_NUMBER >= 3002002 + self->connection->inTransaction = !sqlite3_get_autocommit(self->connection->db); + #endif + #endif + Py_XDECREF(operation_bytestr); Py_XDECREF(parameters); Py_XDECREF(parameters_iter); Py_XDECREF(parameters_list); if (PyErr_Occurred()) { + Py_DECREF(self->rowcount); + self->rowcount = PyInt_FromLong(-1L); return NULL; } else { + Py_DECREF(self->rowcount); + self->rowcount = PyInt_FromLong(rowcount); + Py_INCREF(self); return (PyObject*)self; } @@ -748,7 +753,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) int statement_completed = 0; if (!PyArg_ParseTuple(args, "O", &script_obj)) { - return NULL; + return NULL; } if (!pysqlite_check_thread(self->connection) || !pysqlite_check_connection(self->connection)) { @@ -788,7 +793,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) &statement, &script_cstr); if (rc != SQLITE_OK) { - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } @@ -796,17 +801,18 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) rc = SQLITE_ROW; while (rc == SQLITE_ROW) { rc = _sqlite_step_with_busyhandler(statement, self->connection); + /* TODO: we probably need more error handling here */ } if (rc != SQLITE_DONE) { (void)sqlite3_finalize(statement); - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } rc = sqlite3_finalize(statement); if (rc != SQLITE_OK) { - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } @@ -864,8 +870,9 @@ PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self) if (self->statement) { rc = _sqlite_step_with_busyhandler(self->statement->st, self->connection); if (rc != SQLITE_DONE && rc != SQLITE_ROW) { + (void)pysqlite_statement_reset(self->statement); Py_DECREF(next_row); - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); return NULL; } @@ -890,15 +897,17 @@ PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args) return row; } -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args) +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs) { + static char *kwlist[] = {"size", NULL, NULL}; + PyObject* row; PyObject* list; int maxrows = self->arraysize; int counter = 0; - if (!PyArg_ParseTuple(args, "|i", &maxrows)) { - return NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:fetchmany", kwlist, &maxrows)) { + return NULL; } list = PyList_New(0); @@ -992,7 +1001,7 @@ static PyMethodDef cursor_methods[] = { PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")}, {"fetchone", (PyCFunction)pysqlite_cursor_fetchone, METH_NOARGS, PyDoc_STR("Fetches one row from the resultset.")}, - {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS, + {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Fetches several rows from the resultset.")}, {"fetchall", (PyCFunction)pysqlite_cursor_fetchall, METH_NOARGS, PyDoc_STR("Fetches all rows from the resultset.")}, diff --git a/Modules/_sqlite/cursor.h b/Modules/_sqlite/cursor.h index 5fce64a3fb..d916ca5fc6 100644 --- a/Modules/_sqlite/cursor.h +++ b/Modules/_sqlite/cursor.h @@ -1,6 +1,6 @@ /* cursor.h - definitions for the cursor type * - * Copyright (C) 2004-2006 Gerhard Häring + * Copyright (C) 2004-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -60,7 +60,7 @@ PyObject* pysqlite_cursor_executemany(pysqlite_Cursor* self, PyObject* args); PyObject* pysqlite_cursor_getiter(pysqlite_Cursor *self); PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self); PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args); -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args); +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs); PyObject* pysqlite_cursor_fetchall(pysqlite_Cursor* self, PyObject* args); PyObject* pysqlite_noop(pysqlite_Connection* self, PyObject* args); PyObject* pysqlite_cursor_close(pysqlite_Cursor* self, PyObject* args); diff --git a/Modules/_sqlite/microprotocols.h b/Modules/_sqlite/microprotocols.h index d84ec9397c..c911c8124d 100644 --- a/Modules/_sqlite/microprotocols.h +++ b/Modules/_sqlite/microprotocols.h @@ -28,10 +28,6 @@ #include -#ifdef __cplusplus -extern "C" { -#endif - /** adapters registry **/ extern PyObject *psyco_adapters; diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 8844d81ad4..af7eace685 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -1,25 +1,25 @@ - /* module.c - the module itself - * - * Copyright (C) 2004-2006 Gerhard Häring - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ +/* module.c - the module itself + * + * Copyright (C) 2004-2007 Gerhard Häring + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + */ #include "connection.h" #include "statement.h" @@ -41,6 +41,7 @@ PyObject* pysqlite_Error, *pysqlite_Warning, *pysqlite_InterfaceError, *pysqlite PyObject* converters; int _enable_callback_tracebacks; +int pysqlite_BaseTypeAdapted; static PyObject* module_connect(PyObject* self, PyObject* args, PyObject* kwargs) @@ -50,7 +51,7 @@ static PyObject* module_connect(PyObject* self, PyObject* args, PyObject* * connection.c and must always be copied from there ... */ static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL}; - char* database; + PyObject* database; int detect_types = 0; PyObject* isolation_level; PyObject* factory = NULL; @@ -60,7 +61,7 @@ static PyObject* module_connect(PyObject* self, PyObject* args, PyObject* PyObject* result; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist, &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements)) { return NULL; @@ -133,6 +134,13 @@ static PyObject* module_register_adapter(PyObject* self, PyObject* args, PyObjec return NULL; } + /* a basic type is adapted; there's a performance optimization if that's not the case + * (99 % of all usages) */ + if (type == &PyInt_Type || type == &PyLong_Type || type == &PyFloat_Type + || type == &PyString_Type || type == &PyUnicode_Type || type == &PyBuffer_Type) { + pysqlite_BaseTypeAdapted = 1; + } + microprotocols_add(type, (PyObject*)&pysqlite_PrepareProtocolType, caster); Py_INCREF(Py_None); @@ -379,6 +387,8 @@ PyMODINIT_FUNC init_sqlite3(void) _enable_callback_tracebacks = 0; + pysqlite_BaseTypeAdapted = 0; + /* Original comment form _bsddb.c in the Python core. This is also still * needed nowadays for Python 2.3/2.4. * diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index ada6b4c5fc..b14be2a9ec 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -1,6 +1,6 @@ /* module.h - definitions for the module * - * Copyright (C) 2004-2006 Gerhard Häring + * Copyright (C) 2004-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -25,7 +25,7 @@ #define PYSQLITE_MODULE_H #include "Python.h" -#define PYSQLITE_VERSION "2.3.3" +#define PYSQLITE_VERSION "2.4.1" extern PyObject* pysqlite_Error; extern PyObject* pysqlite_Warning; @@ -51,6 +51,7 @@ extern PyObject* time_sleep; extern PyObject* converters; extern int _enable_callback_tracebacks; +extern int pysqlite_BaseTypeAdapted; #define PARSE_DECLTYPES 1 #define PARSE_COLNAMES 2 diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 83c07908bb..556ea01f41 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -1,6 +1,6 @@ /* statement.c - the statement type * - * Copyright (C) 2005-2006 Gerhard Häring + * Copyright (C) 2005-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -40,6 +40,16 @@ typedef enum { NORMAL } parse_remaining_sql_state; +typedef enum { + TYPE_INT, + TYPE_LONG, + TYPE_FLOAT, + TYPE_STRING, + TYPE_UNICODE, + TYPE_BUFFER, + TYPE_UNKNOWN +} parameter_type; + int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql) { const char* tail; @@ -97,42 +107,96 @@ int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObjec char* string; Py_ssize_t buflen; PyObject* stringval; + parameter_type paramtype; if (parameter == Py_None) { rc = sqlite3_bind_null(self->st, pos); + goto final; + } + + if (PyInt_CheckExact(parameter)) { + paramtype = TYPE_INT; + } else if (PyLong_CheckExact(parameter)) { + paramtype = TYPE_LONG; + } else if (PyFloat_CheckExact(parameter)) { + paramtype = TYPE_FLOAT; + } else if (PyString_CheckExact(parameter)) { + paramtype = TYPE_STRING; + } else if (PyUnicode_CheckExact(parameter)) { + paramtype = TYPE_UNICODE; + } else if (PyBuffer_Check(parameter)) { + paramtype = TYPE_BUFFER; } else if (PyInt_Check(parameter)) { - longval = PyInt_AsLong(parameter); - rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); -#ifdef HAVE_LONG_LONG + paramtype = TYPE_INT; } else if (PyLong_Check(parameter)) { - longlongval = PyLong_AsLongLong(parameter); - /* in the overflow error case, longlongval is -1, and an exception is set */ - rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); -#endif + paramtype = TYPE_LONG; } else if (PyFloat_Check(parameter)) { - rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); - } else if (PyBuffer_Check(parameter)) { - if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { - rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); - } else { - PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); - rc = -1; - } - } else if PyString_Check(parameter) { - string = PyString_AsString(parameter); - rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); - } else if PyUnicode_Check(parameter) { - stringval = PyUnicode_AsUTF8String(parameter); - string = PyString_AsString(stringval); - rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); - Py_DECREF(stringval); + paramtype = TYPE_FLOAT; + } else if (PyString_Check(parameter)) { + paramtype = TYPE_STRING; + } else if (PyUnicode_Check(parameter)) { + paramtype = TYPE_UNICODE; } else { - rc = -1; + paramtype = TYPE_UNKNOWN; } + switch (paramtype) { + case TYPE_INT: + longval = PyInt_AsLong(parameter); + rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); + break; +#ifdef HAVE_LONG_LONG + case TYPE_LONG: + longlongval = PyLong_AsLongLong(parameter); + /* in the overflow error case, longlongval is -1, and an exception is set */ + rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); + break; +#endif + case TYPE_FLOAT: + rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); + break; + case TYPE_STRING: + string = PyString_AS_STRING(parameter); + rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); + break; + case TYPE_UNICODE: + stringval = PyUnicode_AsUTF8String(parameter); + string = PyString_AsString(stringval); + rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); + Py_DECREF(stringval); + break; + case TYPE_BUFFER: + if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { + rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); + } else { + PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); + rc = -1; + } + break; + case TYPE_UNKNOWN: + rc = -1; + } + +final: return rc; } +/* returns 0 if the object is one of Python's internal ones that don't need to be adapted */ +static int _need_adapt(PyObject* obj) +{ + if (pysqlite_BaseTypeAdapted) { + return 1; + } + + if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj) + || PyFloat_CheckExact(obj) || PyString_CheckExact(obj) + || PyUnicode_CheckExact(obj) || PyBuffer_Check(obj)) { + return 0; + } else { + return 1; + } +} + void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters) { PyObject* current_param; @@ -147,7 +211,55 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para num_params_needed = sqlite3_bind_parameter_count(self->st); Py_END_ALLOW_THREADS - if (PyDict_Check(parameters)) { + if (PyTuple_CheckExact(parameters) || PyList_CheckExact(parameters) || (!PyDict_Check(parameters) && PySequence_Check(parameters))) { + /* parameters passed as sequence */ + if (PyTuple_CheckExact(parameters)) { + num_params = PyTuple_GET_SIZE(parameters); + } else if (PyList_CheckExact(parameters)) { + num_params = PyList_GET_SIZE(parameters); + } else { + num_params = PySequence_Size(parameters); + } + if (num_params != num_params_needed) { + PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", + num_params_needed, num_params); + return; + } + for (i = 0; i < num_params; i++) { + if (PyTuple_CheckExact(parameters)) { + current_param = PyTuple_GET_ITEM(parameters, i); + Py_XINCREF(current_param); + } else if (PyList_CheckExact(parameters)) { + current_param = PyList_GET_ITEM(parameters, i); + Py_XINCREF(current_param); + } else { + current_param = PySequence_GetItem(parameters, i); + } + if (!current_param) { + return; + } + + if (!_need_adapt(current_param)) { + adapted = current_param; + } else { + adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); + if (adapted) { + Py_DECREF(current_param); + } else { + PyErr_Clear(); + adapted = current_param; + } + } + + rc = pysqlite_statement_bind_parameter(self, i + 1, adapted); + Py_DECREF(adapted); + + if (rc != SQLITE_OK) { + PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); + return; + } + } + } else if (PyDict_Check(parameters)) { /* parameters passed as dictionary */ for (i = 1; i <= num_params_needed; i++) { Py_BEGIN_ALLOW_THREADS @@ -159,19 +271,27 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para } binding_name++; /* skip first char (the colon) */ - current_param = PyDict_GetItemString(parameters, binding_name); + if (PyDict_CheckExact(parameters)) { + current_param = PyDict_GetItemString(parameters, binding_name); + Py_XINCREF(current_param); + } else { + current_param = PyMapping_GetItemString(parameters, (char*)binding_name); + } if (!current_param) { PyErr_Format(pysqlite_ProgrammingError, "You did not supply a value for binding %d.", i); return; } - Py_INCREF(current_param); - adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); - if (adapted) { - Py_DECREF(current_param); - } else { - PyErr_Clear(); + if (!_need_adapt(current_param)) { adapted = current_param; + } else { + adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); + if (adapted) { + Py_DECREF(current_param); + } else { + PyErr_Clear(); + adapted = current_param; + } } rc = pysqlite_statement_bind_parameter(self, i, adapted); @@ -183,35 +303,7 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para } } } else { - /* parameters passed as sequence */ - num_params = PySequence_Length(parameters); - if (num_params != num_params_needed) { - PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", - num_params_needed, num_params); - return; - } - for (i = 0; i < num_params; i++) { - current_param = PySequence_GetItem(parameters, i); - if (!current_param) { - return; - } - adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); - - if (adapted) { - Py_DECREF(current_param); - } else { - PyErr_Clear(); - adapted = current_param; - } - - rc = pysqlite_statement_bind_parameter(self, i + 1, adapted); - Py_DECREF(adapted); - - if (rc != SQLITE_OK) { - PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); - return; - } - } + PyErr_SetString(PyExc_ValueError, "parameters are of unsupported type"); } } diff --git a/Modules/_sqlite/util.c b/Modules/_sqlite/util.c index 5e78d58846..e06c299564 100644 --- a/Modules/_sqlite/util.c +++ b/Modules/_sqlite/util.c @@ -1,6 +1,6 @@ /* util.c - various utility functions * - * Copyright (C) 2005-2006 Gerhard Häring + * Copyright (C) 2005-2007 Gerhard Häring * * This file is part of pysqlite. * @@ -45,10 +45,15 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection* * Checks the SQLite error code and sets the appropriate DB-API exception. * Returns the error code (0 means no error occurred). */ -int _pysqlite_seterror(sqlite3* db) +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st) { int errorcode; + /* SQLite often doesn't report anything useful, unless you reset the statement first */ + if (st != NULL) { + (void)sqlite3_reset(st); + } + errorcode = sqlite3_errcode(db); switch (errorcode) diff --git a/Modules/_sqlite/util.h b/Modules/_sqlite/util.h index 969c5e52df..6c343298e2 100644 --- a/Modules/_sqlite/util.h +++ b/Modules/_sqlite/util.h @@ -34,5 +34,5 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection* * Checks the SQLite error code and sets the appropriate DB-API exception. * Returns the error code (0 means no error occurred). */ -int _pysqlite_seterror(sqlite3* db); +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st); #endif