LCOV - code coverage report
Current view: top level - sql/backends/monet5/UDF/pyapi3 - pyloader3.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 150 204 73.5 %
Date: 2024-10-03 20:03:20 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /*
       2             :  * SPDX-License-Identifier: MPL-2.0
       3             :  *
       4             :  * This Source Code Form is subject to the terms of the Mozilla Public
       5             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       6             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       7             :  *
       8             :  * Copyright 2024 MonetDB Foundation;
       9             :  * Copyright August 2008 - 2023 MonetDB B.V.;
      10             :  * Copyright 1997 - July 2008 CWI.
      11             :  */
      12             : 
      13             : #include "monetdb_config.h"
      14             : #include "pyapi.h"
      15             : #include "conversion.h"
      16             : #include "connection.h"
      17             : #include "emit.h"
      18             : 
      19             : #include "unicode.h"
      20             : #include "pytypes.h"
      21             : #include "type_conversion.h"
      22             : #include "formatinput.h"
      23             : 
      24          10 : static void _loader_import_array(void) { _import_array(); }
      25             : 
      26          10 : str _loader_init(void)
      27             : {
      28          10 :         str msg = MAL_SUCCEED;
      29          10 :         _loader_import_array();
      30          10 :         msg = _emit_init();
      31          10 :         if (msg != MAL_SUCCEED) {
      32             :                 return msg;
      33             :         }
      34             : 
      35          10 :         if (PyType_Ready(&Py_ConnectionType) < 0)
      36           0 :                 return createException(MAL, "pyapi3.eval",
      37             :                                                            SQLSTATE(PY000) "Failed to initialize loader functions.");
      38             :         return msg;
      39             : }
      40             : 
      41             : static int
      42          72 : pyapi_list_length(list *l)
      43             : {
      44          72 :         if (l)
      45          48 :                 return l->cnt;
      46             :         return 0;
      47             : }
      48             : 
      49             : str
      50          25 : PYAPI3PyAPIevalLoader(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci) {
      51          25 :     sql_func * sqlfun;
      52          25 :     sql_subfunc * sqlmorefun;
      53          25 :     str exprStr;
      54             : 
      55          25 :         const int additional_columns = 2;
      56          25 :         int i = 1, ai = 0;
      57          25 :         char *pycall = NULL;
      58          25 :         str *args = NULL;
      59          25 :         char *msg = MAL_SUCCEED;
      60          25 :         node *argnode, *n, *n2;
      61          25 :         PyObject *pArgs = NULL, *pEmit = NULL,
      62             :                          *pConnection; // this is going to be the parameter tuple
      63          25 :         PyObject *code_object = NULL;
      64          25 :         sql_emit_col *cols = NULL;
      65          25 :         bool gstate = 0;
      66          25 :         int unnamedArgs = 0;
      67          25 :         int argcount = pci->argc;
      68          25 :         bool create_table = false;
      69          25 :         BUN nval = 0;
      70          25 :         int ncols = 0;
      71             : 
      72          25 :         char *loader_additional_args[] = {"_emit", "_conn"};
      73             : 
      74          25 :         if (!PYAPI3PyAPIInitialized()) {
      75           0 :                 throw(MAL, "pyapi3.eval",
      76             :                                 SQLSTATE(PY000) "Embedded Python is enabled but an error was thrown during initialization.");
      77             :         }
      78          25 :         sqlmorefun = *(sql_subfunc**) getArgReference(stk, pci, pci->retc);
      79          25 :         sqlfun = sqlmorefun->func;
      80          25 :         exprStr = *getArgReference_str(stk, pci, pci->retc + 1);
      81             : 
      82          25 :         args = (str *)GDKzalloc(pci->argc * sizeof(str));
      83          25 :         if (!args) {
      84           0 :                 throw(MAL, "pyapi3.eval", SQLSTATE(HY013) MAL_MALLOC_FAIL " arguments.");
      85             :         }
      86             : 
      87             :         // Analyse the SQL_Func structure to get the parameter names
      88          25 :         if (sqlfun != NULL && sqlfun->ops->cnt > 0) {
      89          22 :                 unnamedArgs = pci->retc + 2;
      90          22 :                 argnode = sqlfun->ops->h;
      91          66 :                 while (argnode) {
      92          44 :                         char *argname = ((sql_arg *)argnode->data)->name;
      93          44 :                         args[unnamedArgs++] = GDKstrdup(argname);
      94          44 :                         argnode = argnode->next;
      95             :                 }
      96             :         }
      97             : 
      98             :         // We name all the unknown arguments
      99          69 :         for (i = pci->retc + 2; i < argcount; i++) {
     100          44 :                 if (!args[i]) {
     101           0 :                         char argbuf[64];
     102           0 :                         snprintf(argbuf, sizeof(argbuf), "arg%i", i - pci->retc - 1);
     103           0 :                         args[i] = GDKstrdup(argbuf);
     104             :                 }
     105             :         }
     106          25 :         gstate = Python_ObtainGIL();
     107             : 
     108          25 :         pArgs = PyTuple_New(argcount - pci->retc - 2 + additional_columns);
     109          25 :         if (!pArgs) {
     110           0 :                 msg = createException(MAL, "pyapi3.eval_loader",
     111             :                                                           SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
     112           0 :                 goto wrapup;
     113             :         }
     114             : 
     115          25 :         ai = 0;
     116          25 :         argnode = sqlfun && sqlfun->ops->cnt > 0 ? sqlfun->ops->h : NULL;
     117          69 :         for (i = pci->retc + 2; i < argcount; i++) {
     118          44 :                 PyInput inp;
     119          44 :                 PyObject *val = NULL;
     120          44 :                 inp.bat = NULL;
     121          44 :                 inp.sql_subtype = NULL;
     122             : 
     123          44 :                 if (!isaBatType(getArgType(mb, pci, i))) {
     124          44 :                         inp.scalar = true;
     125          44 :                         inp.bat_type = getArgType(mb, pci, i);
     126          44 :                         inp.count = 1;
     127          44 :                         if (inp.bat_type == TYPE_str) {
     128          18 :                                 inp.dataptr = getArgReference_str(stk, pci, i);
     129             :                         } else {
     130          26 :                                 inp.dataptr = getArgReference(stk, pci, i);
     131             :                         }
     132          44 :                         val = PyArrayObject_FromScalar(&inp, &msg);
     133             :                 } else {
     134           0 :                         BAT* b = BATdescriptor(*getArgReference_bat(stk, pci, i));
     135           0 :                         if (b == NULL) {
     136           0 :                                 msg = createException(
     137             :                                         MAL, "pyapi3.eval_loader",
     138             :                                         SQLSTATE(PY000) "The BAT passed to the function (argument #%d) is NULL.\n",
     139           0 :                                         i - (pci->retc + 2) + 1);
     140           0 :                                 goto wrapup;
     141             :                         }
     142           0 :                         inp.scalar = false;
     143           0 :                         inp.count = BATcount(b);
     144           0 :                         inp.bat_type = getBatType(getArgType(mb, pci, i));
     145           0 :                         inp.bat = b;
     146             : 
     147           0 :                         val = PyMaskedArray_FromBAT(
     148             :                                 &inp, 0, inp.count, &msg,
     149             :                                 false);
     150           0 :                         BBPunfix(inp.bat->batCacheid);
     151             :                 }
     152          44 :                 if (msg != MAL_SUCCEED) {
     153           0 :                         goto wrapup;
     154             :                 }
     155          44 :                 if (PyTuple_SetItem(pArgs, ai++, val) != 0) {
     156           0 :                         msg =
     157           0 :                                 createException(MAL, "pyapi3.eval_loader",
     158             :                                                                 SQLSTATE(PY000) "Failed to set tuple (this shouldn't happen).");
     159           0 :                         goto wrapup;
     160             :                 }
     161             :                 // TODO deal with sql types
     162             :         }
     163             : 
     164          25 :         getArg(pci, 0) = TYPE_void;
     165          25 :         if (sqlmorefun->colnames) {
     166          24 :                 n = sqlmorefun->colnames->h;
     167          24 :                 n2 = sqlmorefun->coltypes->h;
     168          24 :                 ncols = pyapi_list_length(sqlmorefun->coltypes);
     169          24 :                 if (ncols == 0) {
     170           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     171             :                                                                   "No columns supplied.");
     172           0 :                         goto wrapup;
     173             :                 }
     174          24 :                 cols = GDKzalloc(sizeof(sql_emit_col) * ncols);
     175          24 :                 if (!cols) {
     176           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     177             :                                                                   SQLSTATE(HY013) MAL_MALLOC_FAIL "column list");
     178           0 :                         goto wrapup;
     179             :                 }
     180          72 :                 assert(pyapi_list_length(sqlmorefun->colnames) == pyapi_list_length(sqlmorefun->coltypes) * 2);
     181             :                 i = 0;
     182         158 :                 while (n) {
     183         134 :                         sql_subtype* tpe = (sql_subtype*) n2->data;
     184         134 :                         cols[i].name = GDKstrdup(*((char **)n->data));
     185         134 :                         n = n->next;
     186         134 :                         assert(n);
     187         134 :                         cols[i].def = n->data;
     188         134 :                         n = n->next;
     189         134 :                         cols[i].b = COLnew(0, tpe->type->localtype, 0, TRANSIENT);
     190         134 :                         if (cols[i].b == NULL || cols[i].name == NULL) {
     191           0 :                                 do {
     192           0 :                                         BBPreclaim(cols[i].b);
     193           0 :                                         GDKfree(cols[i].name);
     194           0 :                                 } while (i-- > 0);
     195           0 :                                 msg = createException(MAL, "pyapi3.eval", GDK_EXCEPTION);
     196           0 :                                 goto wrapup;
     197             :                         }
     198         134 :                         n2 = n2->next;
     199         134 :                         cols[i].b->tnil = false;
     200         134 :                         cols[i].b->tnonil = false;
     201         134 :                         i++;
     202             :                 }
     203             :         } else {
     204             :                 // set the return value to the correct type to prevent MAL layers from
     205             :                 // complaining
     206             :                 cols = NULL;
     207             :                 ncols = 0;
     208             :                 create_table = true;
     209             :         }
     210             : 
     211          25 :         pConnection = Py_Connection_Create(cntxt, 0, 0);
     212          25 :         pEmit = PyEmit_Create(cols, ncols);
     213          25 :         if (!pConnection || !pEmit) {
     214           0 :                 msg = createException(MAL, "pyapi3.eval_loader",
     215             :                                                           SQLSTATE(HY013) MAL_MALLOC_FAIL "python object");
     216           0 :                 goto wrapup;
     217             :         }
     218             : 
     219          25 :         PyTuple_SetItem(pArgs, ai++, pEmit);
     220          25 :         PyTuple_SetItem(pArgs, ai++, pConnection);
     221             : 
     222          25 :         pycall = FormatCode(exprStr, args, argcount, 4, &code_object, &msg,
     223             :                                                 loader_additional_args, additional_columns);
     224          25 :         if (!pycall && !code_object) {
     225           0 :                 if (msg == MAL_SUCCEED) {
     226           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     227             :                                                                   SQLSTATE(PY000) "Error while parsing Python code.");
     228             :                 }
     229           0 :                 goto wrapup;
     230             :         }
     231             : 
     232             :         {
     233          25 :                 PyObject *pFunc, *pModule, *v, *d, *ret;
     234             : 
     235             :                 // First we will load the main module, this is required
     236          25 :                 pModule = PyImport_AddModule("__main__");
     237          25 :                 if (!pModule) {
     238           0 :                         msg = PyError_CreateException("Failed to load module", NULL);
     239           0 :                         goto wrapup;
     240             :                 }
     241             : 
     242             :                 // Now we will add the UDF to the main module
     243          25 :                 d = PyModule_GetDict(pModule);
     244          25 :                 if (code_object == NULL) {
     245          25 :                         v = PyRun_StringFlags(pycall, Py_file_input, d, NULL, NULL);
     246          25 :                         if (v == NULL) {
     247           0 :                                 msg = PyError_CreateException("Could not parse Python code",
     248             :                                                                                           pycall);
     249           0 :                                 goto wrapup;
     250             :                         }
     251          25 :                         Py_DECREF(v);
     252             : 
     253             :                         // Now we need to obtain a pointer to the function, the function is
     254             :                         // called "pyfun"
     255          25 :                         pFunc = PyObject_GetAttrString(pModule, "pyfun");
     256          25 :                         if (!pFunc || !PyCallable_Check(pFunc)) {
     257           0 :                                 msg = PyError_CreateException("Failed to load function", NULL);
     258           0 :                                 goto wrapup;
     259             :                         }
     260             :                 } else {
     261           0 :                         pFunc = PyFunction_New(code_object, d);
     262           0 :                         if (!pFunc || !PyCallable_Check(pFunc)) {
     263           0 :                                 msg = PyError_CreateException("Failed to load function", NULL);
     264           0 :                                 goto wrapup;
     265             :                         }
     266             :                 }
     267          25 :                 ret = PyObject_CallObject(pFunc, pArgs);
     268             : 
     269          25 :                 if (PyErr_Occurred()) {
     270          12 :                         Py_DECREF(pFunc);
     271          12 :                         msg = PyError_CreateException("Python exception", pycall);
     272          12 :                         if (code_object == NULL) {
     273          12 :                                 PyRun_SimpleString("del pyfun");
     274             :                         }
     275          12 :                         goto wrapup;
     276             :                 }
     277             : 
     278          13 :                 if (ret != Py_None) {
     279           0 :                         if (PyEmit_Emit((PyEmitObject *)pEmit, ret) == NULL) {
     280           0 :                                 Py_DECREF(pFunc);
     281           0 :                                 msg = PyError_CreateException("Python exception", pycall);
     282           0 :                                 goto wrapup;
     283             :                         }
     284             :                 }
     285             : 
     286          13 :                 cols = ((PyEmitObject *)pEmit)->cols;
     287          13 :                 nval = ((PyEmitObject *)pEmit)->nvals;
     288          13 :                 ncols = (int)((PyEmitObject *)pEmit)->ncols;
     289          13 :                 Py_DECREF(pFunc);
     290          13 :                 Py_DECREF(pArgs);
     291          13 :                 pArgs = NULL;
     292             : 
     293          13 :                 if (ncols == 0) {
     294           0 :                         msg = createException(MAL, "pyapi3.eval_loader",
     295             :                                                                   SQLSTATE(PY000) "No elements emitted by the loader.");
     296           0 :                         goto wrapup;
     297             :                 }
     298             :         }
     299             : 
     300          13 :         gstate = Python_ReleaseGIL(gstate);
     301             : 
     302          78 :         for (i = 0; i < ncols; i++) {
     303          52 :                 BAT *b = cols[i].b;
     304          52 :                 BATsetcount(b, nval);
     305          52 :                 b->tkey = false;
     306          52 :                 b->tsorted = false;
     307          52 :                 b->trevsorted = false;
     308             :         }
     309          13 :         if (!create_table) {
     310          12 :                 msg = _connection_append_to_table(cntxt, sqlmorefun->sname,
     311             :                                                                            sqlmorefun->tname, cols, ncols);
     312          12 :                 goto wrapup;
     313             :         } else {
     314           1 :                 msg = _connection_create_table(cntxt, sqlmorefun->sname,
     315             :                                                                            sqlmorefun->tname, cols, ncols);
     316           1 :                 goto wrapup;
     317             :         }
     318             : 
     319          25 : wrapup:
     320          25 :         if (cols) {
     321         161 :                 for (i = 0; i < ncols; i++) {
     322         136 :                         BBPreclaim(cols[i].b);
     323         136 :                         if (cols[i].name) {
     324         136 :                                 GDKfree(cols[i].name);
     325             :                         }
     326             :                 }
     327          25 :                 GDKfree(cols);
     328             :         }
     329          25 :         if (gstate) {
     330          12 :                 if (pArgs) {
     331          12 :                         Py_DECREF(pArgs);
     332             :                 }
     333          12 :                 gstate = Python_ReleaseGIL(gstate);
     334             :         }
     335          25 :         if (pycall)
     336          25 :                 GDKfree(pycall);
     337          25 :         if (args) {
     338          69 :                 for (i = pci->retc + 2; i < argcount; i++) {
     339          44 :                         if (args[i]) {
     340          44 :                                 GDKfree(args[i]);
     341             :                         }
     342             :                 }
     343          25 :                 GDKfree(args);
     344             :         }
     345          25 :         return (msg);
     346             : }

Generated by: LCOV version 1.14