Line data Source code
1 : /*
2 : * SPDX-License-Identifier: MPL-2.0
3 : *
4 : * This Source Code Form is subject to the terms of the Mozilla Public
5 : * License, v. 2.0. If a copy of the MPL was not distributed with this
6 : * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 : *
8 : * Copyright 2024 MonetDB Foundation;
9 : * Copyright August 2008 - 2023 MonetDB B.V.;
10 : * Copyright 1997 - July 2008 CWI.
11 : */
12 :
13 : #include "monetdb_config.h"
14 : #include "pyapi.h"
15 : #include "conversion.h"
16 : #include "connection.h"
17 : #include "emit.h"
18 :
19 : #include "unicode.h"
20 : #include "pytypes.h"
21 : #include "type_conversion.h"
22 : #include "formatinput.h"
23 :
24 10 : static void _loader_import_array(void) { _import_array(); }
25 :
26 10 : str _loader_init(void)
27 : {
28 10 : str msg = MAL_SUCCEED;
29 10 : _loader_import_array();
30 10 : msg = _emit_init();
31 10 : if (msg != MAL_SUCCEED) {
32 : return msg;
33 : }
34 :
35 10 : if (PyType_Ready(&Py_ConnectionType) < 0)
36 0 : return createException(MAL, "pyapi3.eval",
37 : SQLSTATE(PY000) "Failed to initialize loader functions.");
38 : return msg;
39 : }
40 :
41 : static int
42 72 : pyapi_list_length(list *l)
43 : {
44 72 : if (l)
45 48 : return l->cnt;
46 : return 0;
47 : }
48 :
49 : str
50 25 : PYAPI3PyAPIevalLoader(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci) {
51 25 : sql_func * sqlfun;
52 25 : sql_subfunc * sqlmorefun;
53 25 : str exprStr;
54 :
55 25 : const int additional_columns = 2;
56 25 : int i = 1, ai = 0;
57 25 : char *pycall = NULL;
58 25 : str *args = NULL;
59 25 : char *msg = MAL_SUCCEED;
60 25 : node *argnode, *n, *n2;
61 25 : PyObject *pArgs = NULL, *pEmit = NULL,
62 : *pConnection; // this is going to be the parameter tuple
63 25 : PyObject *code_object = NULL;
64 25 : sql_emit_col *cols = NULL;
65 25 : bool gstate = 0;
66 25 : int unnamedArgs = 0;
67 25 : int argcount = pci->argc;
68 25 : bool create_table = false;
69 25 : BUN nval = 0;
70 25 : int ncols = 0;
71 :
72 25 : char *loader_additional_args[] = {"_emit", "_conn"};
73 :
74 25 : if (!PYAPI3PyAPIInitialized()) {
75 0 : throw(MAL, "pyapi3.eval",
76 : SQLSTATE(PY000) "Embedded Python is enabled but an error was thrown during initialization.");
77 : }
78 25 : sqlmorefun = *(sql_subfunc**) getArgReference(stk, pci, pci->retc);
79 25 : sqlfun = sqlmorefun->func;
80 25 : exprStr = *getArgReference_str(stk, pci, pci->retc + 1);
81 :
82 25 : args = (str *)GDKzalloc(pci->argc * sizeof(str));
83 25 : if (!args) {
84 0 : throw(MAL, "pyapi3.eval", SQLSTATE(HY013) MAL_MALLOC_FAIL " arguments.");
85 : }
86 :
87 : // Analyse the SQL_Func structure to get the parameter names
88 25 : if (sqlfun != NULL && sqlfun->ops->cnt > 0) {
89 22 : unnamedArgs = pci->retc + 2;
90 22 : argnode = sqlfun->ops->h;
91 66 : while (argnode) {
92 44 : char *argname = ((sql_arg *)argnode->data)->name;
93 44 : args[unnamedArgs++] = GDKstrdup(argname);
94 44 : argnode = argnode->next;
95 : }
96 : }
97 :
98 : // We name all the unknown arguments
99 69 : for (i = pci->retc + 2; i < argcount; i++) {
100 44 : if (!args[i]) {
101 0 : char argbuf[64];
102 0 : snprintf(argbuf, sizeof(argbuf), "arg%i", i - pci->retc - 1);
103 0 : args[i] = GDKstrdup(argbuf);
104 : }
105 : }
106 25 : gstate = Python_ObtainGIL();
107 :
108 25 : pArgs = PyTuple_New(argcount - pci->retc - 2 + additional_columns);
109 25 : if (!pArgs) {
110 0 : msg = createException(MAL, "pyapi3.eval_loader",
111 : SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
112 0 : goto wrapup;
113 : }
114 :
115 25 : ai = 0;
116 25 : argnode = sqlfun && sqlfun->ops->cnt > 0 ? sqlfun->ops->h : NULL;
117 69 : for (i = pci->retc + 2; i < argcount; i++) {
118 44 : PyInput inp;
119 44 : PyObject *val = NULL;
120 44 : inp.bat = NULL;
121 44 : inp.sql_subtype = NULL;
122 :
123 44 : if (!isaBatType(getArgType(mb, pci, i))) {
124 44 : inp.scalar = true;
125 44 : inp.bat_type = getArgType(mb, pci, i);
126 44 : inp.count = 1;
127 44 : if (inp.bat_type == TYPE_str) {
128 18 : inp.dataptr = getArgReference_str(stk, pci, i);
129 : } else {
130 26 : inp.dataptr = getArgReference(stk, pci, i);
131 : }
132 44 : val = PyArrayObject_FromScalar(&inp, &msg);
133 : } else {
134 0 : BAT* b = BATdescriptor(*getArgReference_bat(stk, pci, i));
135 0 : if (b == NULL) {
136 0 : msg = createException(
137 : MAL, "pyapi3.eval_loader",
138 : SQLSTATE(PY000) "The BAT passed to the function (argument #%d) is NULL.\n",
139 0 : i - (pci->retc + 2) + 1);
140 0 : goto wrapup;
141 : }
142 0 : inp.scalar = false;
143 0 : inp.count = BATcount(b);
144 0 : inp.bat_type = getBatType(getArgType(mb, pci, i));
145 0 : inp.bat = b;
146 :
147 0 : val = PyMaskedArray_FromBAT(
148 : &inp, 0, inp.count, &msg,
149 : false);
150 0 : BBPunfix(inp.bat->batCacheid);
151 : }
152 44 : if (msg != MAL_SUCCEED) {
153 0 : goto wrapup;
154 : }
155 44 : if (PyTuple_SetItem(pArgs, ai++, val) != 0) {
156 0 : msg =
157 0 : createException(MAL, "pyapi3.eval_loader",
158 : SQLSTATE(PY000) "Failed to set tuple (this shouldn't happen).");
159 0 : goto wrapup;
160 : }
161 : // TODO deal with sql types
162 : }
163 :
164 25 : getArg(pci, 0) = TYPE_void;
165 25 : if (sqlmorefun->colnames) {
166 24 : n = sqlmorefun->colnames->h;
167 24 : n2 = sqlmorefun->coltypes->h;
168 24 : ncols = pyapi_list_length(sqlmorefun->coltypes);
169 24 : if (ncols == 0) {
170 0 : msg = createException(MAL, "pyapi3.eval_loader",
171 : "No columns supplied.");
172 0 : goto wrapup;
173 : }
174 24 : cols = GDKzalloc(sizeof(sql_emit_col) * ncols);
175 24 : if (!cols) {
176 0 : msg = createException(MAL, "pyapi3.eval_loader",
177 : SQLSTATE(HY013) MAL_MALLOC_FAIL "column list");
178 0 : goto wrapup;
179 : }
180 72 : assert(pyapi_list_length(sqlmorefun->colnames) == pyapi_list_length(sqlmorefun->coltypes) * 2);
181 : i = 0;
182 158 : while (n) {
183 134 : sql_subtype* tpe = (sql_subtype*) n2->data;
184 134 : cols[i].name = GDKstrdup(*((char **)n->data));
185 134 : n = n->next;
186 134 : assert(n);
187 134 : cols[i].def = n->data;
188 134 : n = n->next;
189 134 : cols[i].b = COLnew(0, tpe->type->localtype, 0, TRANSIENT);
190 134 : if (cols[i].b == NULL || cols[i].name == NULL) {
191 0 : do {
192 0 : BBPreclaim(cols[i].b);
193 0 : GDKfree(cols[i].name);
194 0 : } while (i-- > 0);
195 0 : msg = createException(MAL, "pyapi3.eval", GDK_EXCEPTION);
196 0 : goto wrapup;
197 : }
198 134 : n2 = n2->next;
199 134 : cols[i].b->tnil = false;
200 134 : cols[i].b->tnonil = false;
201 134 : i++;
202 : }
203 : } else {
204 : // set the return value to the correct type to prevent MAL layers from
205 : // complaining
206 : cols = NULL;
207 : ncols = 0;
208 : create_table = true;
209 : }
210 :
211 25 : pConnection = Py_Connection_Create(cntxt, 0, 0);
212 25 : pEmit = PyEmit_Create(cols, ncols);
213 25 : if (!pConnection || !pEmit) {
214 0 : msg = createException(MAL, "pyapi3.eval_loader",
215 : SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
216 0 : goto wrapup;
217 : }
218 :
219 25 : PyTuple_SetItem(pArgs, ai++, pEmit);
220 25 : PyTuple_SetItem(pArgs, ai++, pConnection);
221 :
222 25 : pycall = FormatCode(exprStr, args, argcount, 4, &code_object, &msg,
223 : loader_additional_args, additional_columns);
224 25 : if (!pycall && !code_object) {
225 0 : if (msg == MAL_SUCCEED) {
226 0 : msg = createException(MAL, "pyapi3.eval_loader",
227 : SQLSTATE(PY000) "Error while parsing Python code.");
228 : }
229 0 : goto wrapup;
230 : }
231 :
232 : {
233 25 : PyObject *pFunc, *pModule, *v, *d, *ret;
234 :
235 : // First we will load the main module, this is required
236 25 : pModule = PyImport_AddModule("__main__");
237 25 : if (!pModule) {
238 0 : msg = PyError_CreateException("Failed to load module", NULL);
239 0 : goto wrapup;
240 : }
241 :
242 : // Now we will add the UDF to the main module
243 25 : d = PyModule_GetDict(pModule);
244 25 : if (code_object == NULL) {
245 25 : v = PyRun_StringFlags(pycall, Py_file_input, d, NULL, NULL);
246 25 : if (v == NULL) {
247 0 : msg = PyError_CreateException("Could not parse Python code",
248 : pycall);
249 0 : goto wrapup;
250 : }
251 25 : Py_DECREF(v);
252 :
253 : // Now we need to obtain a pointer to the function, the function is
254 : // called "pyfun"
255 25 : pFunc = PyObject_GetAttrString(pModule, "pyfun");
256 25 : if (!pFunc || !PyCallable_Check(pFunc)) {
257 0 : msg = PyError_CreateException("Failed to load function", NULL);
258 0 : goto wrapup;
259 : }
260 : } else {
261 0 : pFunc = PyFunction_New(code_object, d);
262 0 : if (!pFunc || !PyCallable_Check(pFunc)) {
263 0 : msg = PyError_CreateException("Failed to load function", NULL);
264 0 : goto wrapup;
265 : }
266 : }
267 25 : ret = PyObject_CallObject(pFunc, pArgs);
268 :
269 25 : if (PyErr_Occurred()) {
270 12 : Py_DECREF(pFunc);
271 12 : msg = PyError_CreateException("Python exception", pycall);
272 12 : if (code_object == NULL) {
273 12 : PyRun_SimpleString("del pyfun");
274 : }
275 12 : goto wrapup;
276 : }
277 :
278 13 : if (ret != Py_None) {
279 0 : if (PyEmit_Emit((PyEmitObject *)pEmit, ret) == NULL) {
280 0 : Py_DECREF(pFunc);
281 0 : msg = PyError_CreateException("Python exception", pycall);
282 0 : goto wrapup;
283 : }
284 : }
285 :
286 13 : cols = ((PyEmitObject *)pEmit)->cols;
287 13 : nval = ((PyEmitObject *)pEmit)->nvals;
288 13 : ncols = (int)((PyEmitObject *)pEmit)->ncols;
289 13 : Py_DECREF(pFunc);
290 13 : Py_DECREF(pArgs);
291 13 : pArgs = NULL;
292 :
293 13 : if (ncols == 0) {
294 0 : msg = createException(MAL, "pyapi3.eval_loader",
295 : SQLSTATE(PY000) "No elements emitted by the loader.");
296 0 : goto wrapup;
297 : }
298 : }
299 :
300 13 : gstate = Python_ReleaseGIL(gstate);
301 :
302 78 : for (i = 0; i < ncols; i++) {
303 52 : BAT *b = cols[i].b;
304 52 : BATsetcount(b, nval);
305 52 : b->tkey = false;
306 52 : b->tsorted = false;
307 52 : b->trevsorted = false;
308 : }
309 13 : if (!create_table) {
310 12 : msg = _connection_append_to_table(cntxt, sqlmorefun->sname,
311 : sqlmorefun->tname, cols, ncols);
312 12 : goto wrapup;
313 : } else {
314 1 : msg = _connection_create_table(cntxt, sqlmorefun->sname,
315 : sqlmorefun->tname, cols, ncols);
316 1 : goto wrapup;
317 : }
318 :
319 25 : wrapup:
320 25 : if (cols) {
321 161 : for (i = 0; i < ncols; i++) {
322 136 : BBPreclaim(cols[i].b);
323 136 : if (cols[i].name) {
324 136 : GDKfree(cols[i].name);
325 : }
326 : }
327 25 : GDKfree(cols);
328 : }
329 25 : if (gstate) {
330 12 : if (pArgs) {
331 12 : Py_DECREF(pArgs);
332 : }
333 12 : gstate = Python_ReleaseGIL(gstate);
334 : }
335 25 : if (pycall)
336 25 : GDKfree(pycall);
337 25 : if (args) {
338 69 : for (i = pci->retc + 2; i < argcount; i++) {
339 44 : if (args[i]) {
340 44 : GDKfree(args[i]);
341 : }
342 : }
343 25 : GDKfree(args);
344 : }
345 25 : return (msg);
346 : }
|