diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 3b287b0c5abacf..e124b17782dfe5 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1337,7 +1337,6 @@ pysqlite_connection_call(pysqlite_Connection *self, PyObject *args, PyObject* sql; pysqlite_Statement* statement; PyObject* weakref; - int rc; if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { return NULL; @@ -1351,31 +1350,11 @@ pysqlite_connection_call(pysqlite_Connection *self, PyObject *args, _pysqlite_drop_unused_statement_references(self); - statement = PyObject_GC_New(pysqlite_Statement, pysqlite_StatementType); - if (!statement) { + statement = pysqlite_statement_create(self, sql); + if (statement == NULL) { return NULL; } - statement->db = NULL; - statement->st = NULL; - statement->sql = NULL; - statement->in_use = 0; - statement->in_weakreflist = NULL; - - rc = pysqlite_statement_create(statement, self, sql); - if (rc != SQLITE_OK) { - if (rc == PYSQLITE_TOO_MUCH_SQL) { - PyErr_SetString(pysqlite_Warning, "You can only execute one statement at a time."); - } else if (rc == PYSQLITE_SQL_WRONG_TYPE) { - if (PyErr_ExceptionMatches(PyExc_TypeError)) - PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string."); - } else { - (void)pysqlite_statement_reset(statement); - _pysqlite_seterror(self->db, NULL); - } - goto error; - } - weakref = PyWeakref_NewRef((PyObject*)statement, NULL); if (weakref == NULL) goto error; diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 558b43a597000e..0335e98730a126 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -507,13 +507,8 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* operation if (self->statement->in_use) { Py_SETREF(self->statement, - PyObject_GC_New(pysqlite_Statement, pysqlite_StatementType)); - if (!self->statement) { - goto error; - } - rc = pysqlite_statement_create(self->statement, self->connection, operation); - if (rc != SQLITE_OK) { - Py_CLEAR(self->statement); + pysqlite_statement_create(self->connection, operation)); + if (self->statement == NULL) { goto error; } } diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 332bf030a9b909..e6dc4fd89528d2 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -48,7 +48,8 @@ typedef enum { TYPE_UNKNOWN } parameter_type; -int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql) +pysqlite_Statement * +pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) { const char* tail; int rc; @@ -56,27 +57,36 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con Py_ssize_t sql_cstr_len; const char* p; - self->st = NULL; - self->in_use = 0; - assert(PyUnicode_Check(sql)); sql_cstr = PyUnicode_AsUTF8AndSize(sql, &sql_cstr_len); if (sql_cstr == NULL) { - rc = PYSQLITE_SQL_WRONG_TYPE; - return rc; + PyErr_Format(pysqlite_Warning, + "SQL is of wrong type ('%s'). Must be string.", + Py_TYPE(sql)->tp_name); + return NULL; } if (strlen(sql_cstr) != (size_t)sql_cstr_len) { - PyErr_SetString(PyExc_ValueError, "the query contains a null character"); - return PYSQLITE_SQL_WRONG_TYPE; + PyErr_SetString(PyExc_ValueError, + "the query contains a null character"); + return NULL; } - self->in_weakreflist = NULL; + pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement, + pysqlite_StatementType); + if (self == NULL) { + return NULL; + } + + self->db = connection->db; + self->st = NULL; self->sql = Py_NewRef(sql); + self->in_use = 0; + self->is_dml = 0; + self->in_weakreflist = NULL; /* Determine if the statement is a DML statement. SELECT is the only exception. See #9924. */ - self->is_dml = 0; for (p = sql_cstr; *p != 0; p++) { switch (*p) { case ' ': @@ -94,22 +104,33 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con } Py_BEGIN_ALLOW_THREADS - rc = sqlite3_prepare_v2(connection->db, + rc = sqlite3_prepare_v2(self->db, sql_cstr, -1, &self->st, &tail); Py_END_ALLOW_THREADS - self->db = connection->db; + PyObject_GC_Track(self); + + if (rc != SQLITE_OK) { + _pysqlite_seterror(self->db, NULL); + goto error; + } if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) { (void)sqlite3_finalize(self->st); self->st = NULL; - rc = PYSQLITE_TOO_MUCH_SQL; + PyErr_SetString(pysqlite_Warning, + "You can only execute one statement at a time."); + goto error; } - return rc; + return self; + +error: + Py_DECREF(self); + return NULL; } int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter) diff --git a/Modules/_sqlite/statement.h b/Modules/_sqlite/statement.h index 56ff7271448d1c..e8c86a0ec963fd 100644 --- a/Modules/_sqlite/statement.h +++ b/Modules/_sqlite/statement.h @@ -29,9 +29,6 @@ #include "connection.h" #include "sqlite3.h" -#define PYSQLITE_TOO_MUCH_SQL (-100) -#define PYSQLITE_SQL_WRONG_TYPE (-101) - typedef struct { PyObject_HEAD @@ -45,7 +42,7 @@ typedef struct extern PyTypeObject *pysqlite_StatementType; -int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql); +pysqlite_Statement *pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql); int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter); void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters);