granicus.if.org Git - postgresql/blob - src/pl/plpython/plpy_plpymodule.c
Be more careful about Python refcounts while creating exception objects.
[postgresql] / src / pl / plpython / plpy_plpymodule.c
1 /*
2  * the plpy module
3  *
4  * src/pl/plpython/plpy_plpymodule.c
5  */
6
7 #include "postgres.h"
8
9 #include "mb/pg_wchar.h"
10 #include "utils/builtins.h"
11
12 #include "plpython.h"
13
14 #include "plpy_plpymodule.h"
15
16 #include "plpy_cursorobject.h"
17 #include "plpy_elog.h"
18 #include "plpy_planobject.h"
19 #include "plpy_resultobject.h"
20 #include "plpy_spi.h"
21 #include "plpy_subxactobject.h"
22
23
/*
 * Lookup table from an integer SQLSTATE code (hashed as a plain int key) to a
 * PLyExceptionEntry holding the matching plpy.spiexceptions class.  Created
 * in PLy_add_exceptions() and filled in by PLy_generate_spi_exceptions().
 */
HTAB       *PLy_spi_exceptions = NULL;


/* helpers that build the spiexceptions module and its exception classes */
static void PLy_add_exceptions(PyObject *plpy);
static PyObject *PLy_create_exception(char *name,
                                         PyObject *base, PyObject *dict,
                                         const char *modname, PyObject *mod);
static void PLy_generate_spi_exceptions(PyObject *mod, PyObject *base);

/* module functions */

/* logging functions, exposed as plpy.debug() ... plpy.fatal() */
static PyObject *PLy_debug(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_log(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_info(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_notice(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_warning(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_error(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_fatal(PyObject *self, PyObject *args, PyObject *kw);

/* string-escaping functions, exposed as plpy.quote_literal() etc. */
static PyObject *PLy_quote_literal(PyObject *self, PyObject *args);
static PyObject *PLy_quote_nullable(PyObject *self, PyObject *args);
static PyObject *PLy_quote_ident(PyObject *self, PyObject *args);
44
45
/*
 * A list of all known exceptions, generated from backend/utils/errcodes.txt.
 *
 * The entries come from the generated header spiexceptions.h.  The list ends
 * with a {NULL, NULL, 0} sentinel; PLy_generate_spi_exceptions() stops at
 * the first entry whose name is NULL.
 */
typedef struct ExceptionMap
{
        char       *name;               /* class name passed to PyErr_NewException */
        char       *classname;          /* attribute name in the spiexceptions module */
        int                     sqlstate;       /* integer-encoded SQLSTATE, decoded with
                                                                 * unpack_sql_state() */
} ExceptionMap;

static const ExceptionMap exception_map[] = {
#include "spiexceptions.h"
        {NULL, NULL, 0}
};
58
/*
 * Functions exported by the plpy module.  The logging functions also take
 * keyword arguments (METH_KEYWORDS); PLy_output() handles those keywords.
 */
static PyMethodDef PLy_methods[] = {
        /*
         * logging methods
         */
        {"debug", (PyCFunction) PLy_debug, METH_VARARGS | METH_KEYWORDS, NULL},
        {"log", (PyCFunction) PLy_log, METH_VARARGS | METH_KEYWORDS, NULL},
        {"info", (PyCFunction) PLy_info, METH_VARARGS | METH_KEYWORDS, NULL},
        {"notice", (PyCFunction) PLy_notice, METH_VARARGS | METH_KEYWORDS, NULL},
        {"warning", (PyCFunction) PLy_warning, METH_VARARGS | METH_KEYWORDS, NULL},
        {"error", (PyCFunction) PLy_error, METH_VARARGS | METH_KEYWORDS, NULL},
        {"fatal", (PyCFunction) PLy_fatal, METH_VARARGS | METH_KEYWORDS, NULL},

        /*
         * create a stored plan
         */
        {"prepare", PLy_spi_prepare, METH_VARARGS, NULL},

        /*
         * execute a plan or query
         */
        {"execute", PLy_spi_execute, METH_VARARGS, NULL},

        /*
         * escaping strings
         */
        {"quote_literal", PLy_quote_literal, METH_VARARGS, NULL},
        {"quote_nullable", PLy_quote_nullable, METH_VARARGS, NULL},
        {"quote_ident", PLy_quote_ident, METH_VARARGS, NULL},

        /*
         * create the subtransaction context manager
         */
        {"subtransaction", PLy_subtransaction_new, METH_NOARGS, NULL},

        /*
         * create a cursor
         */
        {"cursor", PLy_cursor, METH_VARARGS, NULL},

        /* sentinel */
        {NULL, NULL, 0, NULL}
};

/*
 * The spiexceptions module exports no functions.  It only holds the
 * exception classes that PLy_generate_spi_exceptions() adds to it.
 */
static PyMethodDef PLy_exc_methods[] = {
        {NULL, NULL, 0, NULL}
};
#if PY_MAJOR_VERSION >= 3
/* Python 3 module definition for "plpy"; its functions are PLy_methods. */
static PyModuleDef PLy_module = {
        PyModuleDef_HEAD_INIT,          /* m_base */
        "plpy",                                         /* m_name */
        NULL,                                           /* m_doc */
        -1,                                                     /* m_size */
        PLy_methods,                            /* m_methods */
};

/*
 * Python 3 module definition for "spiexceptions".  PLy_add_exceptions()
 * creates it and attaches it to plpy as plpy.spiexceptions.
 */
static PyModuleDef PLy_exc_module = {
        PyModuleDef_HEAD_INIT,          /* m_base */
        "spiexceptions",                        /* m_name */
        NULL,                                           /* m_doc */
        -1,                                                     /* m_size */
        PLy_exc_methods,                        /* m_methods */
        NULL,                                           /* m_reload */
        NULL,                                           /* m_traverse */
        NULL,                                           /* m_clear */
        NULL                                            /* m_free */
};

/*
 * Python 3 initialization function for the plpy module.  It creates the
 * module from PLy_module and adds the plpy exception classes and the
 * spiexceptions submodule via PLy_add_exceptions().  It returns the new
 * module, or NULL if PyModule_Create fails.  Failures inside
 * PLy_add_exceptions() are reported with PLy_elog(ERROR).
 *
 * NOTE(review): presumably registered with Python's import machinery
 * elsewhere (e.g. via PyImport_AppendInittab); that is not visible here.
 *
 * Must have external linkage, because PyMODINIT_FUNC does dllexport on
 * Windows-like platforms.
 */
PyMODINIT_FUNC
PyInit_plpy(void)
{
        PyObject   *m;

        m = PyModule_Create(&PLy_module);
        if (m == NULL)
                return NULL;

        PLy_add_exceptions(m);

        return m;
}
#endif   /* PY_MAJOR_VERSION >= 3 */
144
145 void
146 PLy_init_plpy(void)
147 {
148         PyObject   *main_mod,
149                            *main_dict,
150                            *plpy_mod;
151
152 #if PY_MAJOR_VERSION < 3
153         PyObject   *plpy;
154 #endif
155
156         /*
157          * initialize plpy module
158          */
159         PLy_plan_init_type();
160         PLy_result_init_type();
161         PLy_subtransaction_init_type();
162         PLy_cursor_init_type();
163
164 #if PY_MAJOR_VERSION >= 3
165         PyModule_Create(&PLy_module);
166         /* for Python 3 we initialized the exceptions in PyInit_plpy */
167 #else
168         plpy = Py_InitModule("plpy", PLy_methods);
169         PLy_add_exceptions(plpy);
170 #endif
171
172         /* PyDict_SetItemString(plpy, "PlanType", (PyObject *) &PLy_PlanType); */
173
174         /*
175          * initialize main module, and add plpy
176          */
177         main_mod = PyImport_AddModule("__main__");
178         main_dict = PyModule_GetDict(main_mod);
179         plpy_mod = PyImport_AddModule("plpy");
180         if (plpy_mod == NULL)
181                 PLy_elog(ERROR, "could not import \"plpy\" module");
182         PyDict_SetItemString(main_dict, "plpy", plpy_mod);
183         if (PyErr_Occurred())
184                 PLy_elog(ERROR, "could not import \"plpy\" module");
185 }
186
187 static void
188 PLy_add_exceptions(PyObject *plpy)
189 {
190         PyObject   *excmod;
191         HASHCTL         hash_ctl;
192
193 #if PY_MAJOR_VERSION < 3
194         excmod = Py_InitModule("spiexceptions", PLy_exc_methods);
195 #else
196         excmod = PyModule_Create(&PLy_exc_module);
197 #endif
198         if (excmod == NULL)
199                 PLy_elog(ERROR, "could not create the spiexceptions module");
200
201         /*
202          * PyModule_AddObject does not add a refcount to the object, for some odd
203          * reason; we must do that.
204          */
205         Py_INCREF(excmod);
206         if (PyModule_AddObject(plpy, "spiexceptions", excmod) < 0)
207                 PLy_elog(ERROR, "could not add the spiexceptions module");
208
209         PLy_exc_error = PLy_create_exception("plpy.Error", NULL, NULL,
210                                                                                  "Error", plpy);
211         PLy_exc_fatal = PLy_create_exception("plpy.Fatal", NULL, NULL,
212                                                                                  "Fatal", plpy);
213         PLy_exc_spi_error = PLy_create_exception("plpy.SPIError", NULL, NULL,
214                                                                                          "SPIError", plpy);
215
216         memset(&hash_ctl, 0, sizeof(hash_ctl));
217         hash_ctl.keysize = sizeof(int);
218         hash_ctl.entrysize = sizeof(PLyExceptionEntry);
219         PLy_spi_exceptions = hash_create("PL/Python SPI exceptions", 256,
220                                                                          &hash_ctl, HASH_ELEM | HASH_BLOBS);
221
222         PLy_generate_spi_exceptions(excmod, PLy_exc_spi_error);
223 }
224
225 /*
226  * Create an exception object and add it to the module
227  */
228 static PyObject *
229 PLy_create_exception(char *name, PyObject *base, PyObject *dict,
230                                          const char *modname, PyObject *mod)
231 {
232         PyObject   *exc;
233
234         exc = PyErr_NewException(name, base, dict);
235         if (exc == NULL)
236                 PLy_elog(ERROR, "could not create exception \"%s\"", name);
237
238         /*
239          * PyModule_AddObject does not add a refcount to the object, for some odd
240          * reason; we must do that.
241          */
242         Py_INCREF(exc);
243         PyModule_AddObject(mod, modname, exc);
244
245         /*
246          * The caller will also store a pointer to the exception object in some
247          * permanent variable, so add another ref to account for that.  This is
248          * probably excessively paranoid, but let's be sure.
249          */
250         Py_INCREF(exc);
251         return exc;
252 }
253
254 /*
255  * Add all the autogenerated exceptions as subclasses of SPIError
256  */
257 static void
258 PLy_generate_spi_exceptions(PyObject *mod, PyObject *base)
259 {
260         int                     i;
261
262         for (i = 0; exception_map[i].name != NULL; i++)
263         {
264                 bool            found;
265                 PyObject   *exc;
266                 PLyExceptionEntry *entry;
267                 PyObject   *sqlstate;
268                 PyObject   *dict = PyDict_New();
269
270                 if (dict == NULL)
271                         PLy_elog(ERROR, "could not generate SPI exceptions");
272
273                 sqlstate = PyString_FromString(unpack_sql_state(exception_map[i].sqlstate));
274                 if (sqlstate == NULL)
275                         PLy_elog(ERROR, "could not generate SPI exceptions");
276
277                 PyDict_SetItemString(dict, "sqlstate", sqlstate);
278                 Py_DECREF(sqlstate);
279
280                 exc = PLy_create_exception(exception_map[i].name, base, dict,
281                                                                    exception_map[i].classname, mod);
282
283                 entry = hash_search(PLy_spi_exceptions, &exception_map[i].sqlstate,
284                                                         HASH_ENTER, &found);
285                 Assert(!found);
286                 entry->exc = exc;
287         }
288 }
289
290
291 /*
292  * the python interface to the elog function
293  * don't confuse these with PLy_elog
294  */
295 static PyObject *PLy_output(volatile int level, PyObject *self,
296                    PyObject *args, PyObject *kw);
297
298 static PyObject *
299 PLy_debug(PyObject *self, PyObject *args, PyObject *kw)
300 {
301         return PLy_output(DEBUG2, self, args, kw);
302 }
303
304 static PyObject *
305 PLy_log(PyObject *self, PyObject *args, PyObject *kw)
306 {
307         return PLy_output(LOG, self, args, kw);
308 }
309
310 static PyObject *
311 PLy_info(PyObject *self, PyObject *args, PyObject *kw)
312 {
313         return PLy_output(INFO, self, args, kw);
314 }
315
316 static PyObject *
317 PLy_notice(PyObject *self, PyObject *args, PyObject *kw)
318 {
319         return PLy_output(NOTICE, self, args, kw);
320 }
321
322 static PyObject *
323 PLy_warning(PyObject *self, PyObject *args, PyObject *kw)
324 {
325         return PLy_output(WARNING, self, args, kw);
326 }
327
328 static PyObject *
329 PLy_error(PyObject *self, PyObject *args, PyObject *kw)
330 {
331         return PLy_output(ERROR, self, args, kw);
332 }
333
334 static PyObject *
335 PLy_fatal(PyObject *self, PyObject *args, PyObject *kw)
336 {
337         return PLy_output(FATAL, self, args, kw);
338 }
339
340 static PyObject *
341 PLy_quote_literal(PyObject *self, PyObject *args)
342 {
343         const char *str;
344         char       *quoted;
345         PyObject   *ret;
346
347         if (!PyArg_ParseTuple(args, "s:quote_literal", &str))
348                 return NULL;
349
350         quoted = quote_literal_cstr(str);
351         ret = PyString_FromString(quoted);
352         pfree(quoted);
353
354         return ret;
355 }
356
357 static PyObject *
358 PLy_quote_nullable(PyObject *self, PyObject *args)
359 {
360         const char *str;
361         char       *quoted;
362         PyObject   *ret;
363
364         if (!PyArg_ParseTuple(args, "z:quote_nullable", &str))
365                 return NULL;
366
367         if (str == NULL)
368                 return PyString_FromString("NULL");
369
370         quoted = quote_literal_cstr(str);
371         ret = PyString_FromString(quoted);
372         pfree(quoted);
373
374         return ret;
375 }
376
377 static PyObject *
378 PLy_quote_ident(PyObject *self, PyObject *args)
379 {
380         const char *str;
381         const char *quoted;
382         PyObject   *ret;
383
384         if (!PyArg_ParseTuple(args, "s:quote_ident", &str))
385                 return NULL;
386
387         quoted = quote_identifier(str);
388         ret = PyString_FromString(quoted);
389
390         return ret;
391 }
392
393 /* enforce cast of object to string */
394 static char *
395 object_to_string(PyObject *obj)
396 {
397         if (obj)
398         {
399                 PyObject   *so = PyObject_Str(obj);
400
401                 if (so != NULL)
402                 {
403                         char       *str;
404
405                         str = pstrdup(PyString_AsString(so));
406                         Py_DECREF(so);
407
408                         return str;
409                 }
410         }
411
412         return NULL;
413 }
414
415 static PyObject *
416 PLy_output(volatile int level, PyObject *self, PyObject *args, PyObject *kw)
417 {
418         int                     sqlstate = 0;
419         char       *volatile sqlstatestr = NULL;
420         char       *volatile message = NULL;
421         char       *volatile detail = NULL;
422         char       *volatile hint = NULL;
423         char       *volatile column_name = NULL;
424         char       *volatile constraint_name = NULL;
425         char       *volatile datatype_name = NULL;
426         char       *volatile table_name = NULL;
427         char       *volatile schema_name = NULL;
428         volatile MemoryContext oldcontext;
429         PyObject   *key,
430                            *value;
431         PyObject   *volatile so;
432         Py_ssize_t      pos = 0;
433
434         if (PyTuple_Size(args) == 1)
435         {
436                 /*
437                  * Treat single argument specially to avoid undesirable ('tuple',)
438                  * decoration.
439                  */
440                 PyObject   *o;
441
442                 if (!PyArg_UnpackTuple(args, "plpy.elog", 1, 1, &o))
443                         PLy_elog(ERROR, "could not unpack arguments in plpy.elog");
444                 so = PyObject_Str(o);
445         }
446         else
447                 so = PyObject_Str(args);
448
449         if (so == NULL || ((message = PyString_AsString(so)) == NULL))
450         {
451                 level = ERROR;
452                 message = dgettext(TEXTDOMAIN, "could not parse error message in plpy.elog");
453         }
454         message = pstrdup(message);
455
456         Py_XDECREF(so);
457
458         if (kw != NULL)
459         {
460                 while (PyDict_Next(kw, &pos, &key, &value))
461                 {
462                         char       *keyword = PyString_AsString(key);
463
464                         if (strcmp(keyword, "message") == 0)
465                         {
466                                 /* the message should not be overwriten */
467                                 if (PyTuple_Size(args) != 0)
468                                 {
469                                         PLy_exception_set(PyExc_TypeError, "Argument 'message' given by name and position");
470                                         return NULL;
471                                 }
472
473                                 if (message)
474                                         pfree(message);
475                                 message = object_to_string(value);
476                         }
477                         else if (strcmp(keyword, "detail") == 0)
478                                 detail = object_to_string(value);
479                         else if (strcmp(keyword, "hint") == 0)
480                                 hint = object_to_string(value);
481                         else if (strcmp(keyword, "sqlstate") == 0)
482                                 sqlstatestr = object_to_string(value);
483                         else if (strcmp(keyword, "schema_name") == 0)
484                                 schema_name = object_to_string(value);
485                         else if (strcmp(keyword, "table_name") == 0)
486                                 table_name = object_to_string(value);
487                         else if (strcmp(keyword, "column_name") == 0)
488                                 column_name = object_to_string(value);
489                         else if (strcmp(keyword, "datatype_name") == 0)
490                                 datatype_name = object_to_string(value);
491                         else if (strcmp(keyword, "constraint_name") == 0)
492                                 constraint_name = object_to_string(value);
493                         else
494                         {
495                                 PLy_exception_set(PyExc_TypeError,
496                                          "'%s' is an invalid keyword argument for this function",
497                                                                   keyword);
498                                 return NULL;
499                         }
500                 }
501         }
502
503         if (sqlstatestr != NULL)
504         {
505                 if (strlen(sqlstatestr) != 5)
506                 {
507                         PLy_exception_set(PyExc_ValueError, "invalid SQLSTATE code");
508                         return NULL;
509                 }
510
511                 if (strspn(sqlstatestr, "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ") != 5)
512                 {
513                         PLy_exception_set(PyExc_ValueError, "invalid SQLSTATE code");
514                         return NULL;
515                 }
516
517                 sqlstate = MAKE_SQLSTATE(sqlstatestr[0],
518                                                                  sqlstatestr[1],
519                                                                  sqlstatestr[2],
520                                                                  sqlstatestr[3],
521                                                                  sqlstatestr[4]);
522         }
523
524         oldcontext = CurrentMemoryContext;
525         PG_TRY();
526         {
527                 if (message != NULL)
528                         pg_verifymbstr(message, strlen(message), false);
529                 if (detail != NULL)
530                         pg_verifymbstr(detail, strlen(detail), false);
531                 if (hint != NULL)
532                         pg_verifymbstr(hint, strlen(hint), false);
533                 if (schema_name != NULL)
534                         pg_verifymbstr(schema_name, strlen(schema_name), false);
535                 if (table_name != NULL)
536                         pg_verifymbstr(table_name, strlen(table_name), false);
537                 if (column_name != NULL)
538                         pg_verifymbstr(column_name, strlen(column_name), false);
539                 if (datatype_name != NULL)
540                         pg_verifymbstr(datatype_name, strlen(datatype_name), false);
541                 if (constraint_name != NULL)
542                         pg_verifymbstr(constraint_name, strlen(constraint_name), false);
543
544                 ereport(level,
545                                 ((sqlstate != 0) ? errcode(sqlstate) : 0,
546                                  (message != NULL) ? errmsg_internal("%s", message) : 0,
547                                  (detail != NULL) ? errdetail_internal("%s", detail) : 0,
548                                  (hint != NULL) ? errhint("%s", hint) : 0,
549                                  (column_name != NULL) ?
550                                  err_generic_string(PG_DIAG_COLUMN_NAME, column_name) : 0,
551                                  (constraint_name != NULL) ?
552                         err_generic_string(PG_DIAG_CONSTRAINT_NAME, constraint_name) : 0,
553                                  (datatype_name != NULL) ?
554                                  err_generic_string(PG_DIAG_DATATYPE_NAME, datatype_name) : 0,
555                                  (table_name != NULL) ?
556                                  err_generic_string(PG_DIAG_TABLE_NAME, table_name) : 0,
557                                  (schema_name != NULL) ?
558                                  err_generic_string(PG_DIAG_SCHEMA_NAME, schema_name) : 0));
559         }
560         PG_CATCH();
561         {
562                 ErrorData  *edata;
563
564                 MemoryContextSwitchTo(oldcontext);
565                 edata = CopyErrorData();
566                 FlushErrorState();
567
568                 PLy_exception_set_with_details(PLy_exc_error, edata);
569                 FreeErrorData(edata);
570
571                 return NULL;
572         }
573         PG_END_TRY();
574
575         /*
576          * return a legal object so the interpreter will continue on its merry way
577          */
578         Py_INCREF(Py_None);
579         return Py_None;
580 }