Line data Source code
1 : /*
2 : * SPDX-License-Identifier: MPL-2.0
3 : *
4 : * This Source Code Form is subject to the terms of the Mozilla Public
5 : * License, v. 2.0. If a copy of the MPL was not distributed with this
6 : * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 : *
8 : * Copyright 2024 MonetDB Foundation;
9 : * Copyright August 2008 - 2023 MonetDB B.V.;
10 : * Copyright 1997 - July 2008 CWI.
11 : */
12 :
13 : #include "monetdb_config.h"
14 : #include "opt_for.h"
15 :
#if 0
/*
 * ReplaceWithNil -- overwrite argument `pos` of instruction `p` with a
 * nil BAT constant.
 *
 * pushNilBat() appends the nil BAT as a new trailing argument; that
 * argument is copied into slot `pos` and the trailing slot is dropped
 * again, so the argument count is unchanged on return.  Returns the
 * (possibly reallocated) instruction.  Only referenced by disabled code
 * in OPTforImplementation.
 */
static InstrPtr
ReplaceWithNil(MalBlkPtr mb, InstrPtr p, int pos)
{
	InstrPtr q = pushNilBat(mb, p);	/* nil BAT lands in the last slot */
	int last = q->argc - 1;

	getArg(q, pos) = getArg(q, last);
	q->argc = last;
	return q;
}
#endif
26 :
27 : static bool
28 1 : allConstExcept(MalBlkPtr mb, InstrPtr p, int except)
29 : {
30 2 : for (int j = p->retc; j < p->argc; j++) {
31 2 : if (j != except && getArgType(mb, p, j) >= TYPE_any)
32 : return false;
33 : }
34 : return true;
35 : }
36 :
37 : str
38 568947 : OPTforImplementation(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
39 : {
40 568947 : int i, j, k, limit, slimit;
41 568947 : InstrPtr p = 0, *old = NULL;
42 568947 : int actions = 0;
43 568947 : int *varisfor = NULL, *varforvalue = NULL;
44 568947 : str msg = MAL_SUCCEED;
45 :
46 568947 : (void) cntxt;
47 568947 : (void) stk; /* to fool compilers */
48 :
49 568947 : if (mb->inlineProp)
50 0 : goto wrapup;
51 :
52 568947 : limit = mb->stop;
53 :
54 22576265 : for (i = 0; i < limit; i++) {
55 22007325 : p = mb->stmt[i];
56 22007325 : if (p && p->retc == 1 && getModuleId(p) == forRef
57 12 : && getFunctionId(p) == decompressRef) {
58 : break;
59 : }
60 : }
61 568947 : if (i == limit)
62 568940 : goto wrapup; /* nothing to do */
63 :
64 7 : varisfor = GDKzalloc(2 * mb->vtop * sizeof(int));
65 7 : varforvalue = GDKzalloc(2 * mb->vtop * sizeof(int));
66 7 : if (varisfor == NULL || varforvalue == NULL)
67 0 : goto wrapup;
68 :
69 7 : slimit = mb->ssize;
70 7 : old = mb->stmt;
71 7 : if (newMalBlkStmt(mb, mb->ssize) < 0) {
72 0 : GDKfree(varisfor);
73 0 : GDKfree(varforvalue);
74 0 : throw(MAL, "optimizer.for", SQLSTATE(HY013) MAL_MALLOC_FAIL);
75 : }
76 : // Consolidate the actual need for variables
77 503 : for (i = 0; i < limit; i++) {
78 496 : p = old[i];
79 496 : if (p == 0)
80 0 : continue; //left behind by others?
81 496 : if (p->retc == 1 && getModuleId(p) == forRef
82 8 : && getFunctionId(p) == decompressRef) {
83 : // remember we have encountered a for decompress function
84 8 : k = getArg(p, 0);
85 8 : varisfor[k] = getArg(p, 1);
86 8 : varforvalue[k] = getArg(p, 2);
87 8 : freeInstruction(p);
88 8 : continue;
89 : }
90 1489 : int done = 0;
91 1489 : for (j = p->retc; j < p->argc; j++) {
92 1017 : k = getArg(p, j);
93 1017 : if (varisfor[k]) { // maybe we could delay this usage
94 27 : if (getModuleId(p) == algebraRef
95 19 : && getFunctionId(p) == projectionRef) {
96 : /* projection(cand, col) with col = for.decompress(o,min_val)
97 : * v1 = projection(cand, o)
98 : * for.decompress(v1, min_val) */
99 14 : InstrPtr r = copyInstruction(p);
100 14 : if (r == NULL) {
101 0 : msg = createException(MAL, "optimizer.for",
102 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
103 0 : break;
104 : }
105 14 : int tpe = getVarType(mb, varisfor[k]);
106 14 : int l = getArg(r, 0);
107 14 : getArg(r, 0) = newTmpVariable(mb, tpe);
108 14 : getArg(r, j) = varisfor[k];
109 14 : varisfor[l] = getArg(r, 0);
110 14 : varforvalue[l] = varforvalue[k];
111 14 : pushInstruction(mb, r);
112 14 : freeInstruction(p);
113 14 : done = 1;
114 14 : break;
115 13 : } else if (p->argc == 2 && p->retc == 1
116 1 : && p->barrier == ASSIGNsymbol) {
117 : /* a = b */
118 0 : int l = getArg(p, 0);
119 0 : varisfor[l] = varisfor[k];
120 0 : varforvalue[l] = varforvalue[k];
121 0 : freeInstruction(p);
122 0 : done = 1;
123 0 : break;
124 13 : } else if (getModuleId(p) == algebraRef
125 5 : && getFunctionId(p) == subsliceRef) {
126 : /* pos = subslice(col, l, h) with col = for.decompress(o,min_val)
127 : * pos = subslice(o, l, h) */
128 0 : InstrPtr r = copyInstruction(p);
129 0 : if (r == NULL) {
130 0 : msg = createException(MAL, "optimizer.for",
131 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
132 0 : break;
133 : }
134 0 : getArg(r, j) = varisfor[k];
135 0 : pushInstruction(mb, r);
136 0 : freeInstruction(p);
137 0 : done = 1;
138 0 : break;
139 13 : } else if ((getModuleId(p) == batRef
140 0 : && getFunctionId(p) == mirrorRef)
141 13 : || (getModuleId(p) == batcalcRef
142 2 : && getFunctionId(p) == identityRef)) {
143 : /* id = mirror/identity(col) with col = for.decompress(o,min_val)
144 : * id = mirror/identity(o) */
145 0 : InstrPtr r = copyInstruction(p);
146 0 : if (r == NULL) {
147 0 : msg = createException(MAL, "optimizer.for",
148 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
149 0 : break;
150 : }
151 0 : getArg(r, j) = varisfor[k];
152 0 : pushInstruction(mb, r);
153 0 : freeInstruction(p);
154 0 : done = 1;
155 0 : break;
156 13 : } else if (getFunctionId(p) == thetaselectRef) {
157 : /* pos = thetaselect(col, cand, l, ...) with col = for.decompress(o, minval)
158 : * l = calc.-(l, minval);
159 : * nl = calc.bte(l);
160 : * or
161 : * nl = calc.sht(l);
162 : * pos = select(o, cand, nl, ...) */
163 :
164 1 : InstrPtr q = newInstructionArgs(mb, calcRef, minusRef, 3);
165 1 : if (q == NULL) {
166 0 : msg = createException(MAL, "optimizer.for",
167 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
168 0 : break;
169 : }
170 1 : int tpe = getVarType(mb, getArg(p, 3));
171 1 : getArg(q, 0) = newTmpVariable(mb, tpe);
172 1 : q = pushArgument(mb, q, getArg(p, 3));
173 1 : q = pushArgument(mb, q, varforvalue[k]);
174 1 : pushInstruction(mb, q);
175 :
176 1 : InstrPtr r;
177 1 : tpe = getBatType(getVarType(mb, varisfor[k]));
178 1 : if (tpe == TYPE_bte)
179 1 : r = newInstructionArgs(mb, calcRef, putName("bte"), 2);
180 : else
181 0 : r = newInstructionArgs(mb, calcRef, putName("sht"), 2);
182 1 : if (r == NULL) {
183 0 : msg = createException(MAL, "optimizer.for",
184 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
185 0 : break;
186 : }
187 1 : getArg(r, 0) = newTmpVariable(mb, tpe);
188 1 : r = pushArgument(mb, r, getArg(q, 0));
189 1 : pushInstruction(mb, r);
190 :
191 1 : q = copyInstruction(p);
192 1 : if (q == NULL) {
193 0 : msg = createException(MAL, "optimizer.for",
194 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
195 0 : break;
196 : }
197 1 : getArg(q, j) = varisfor[k];
198 1 : getArg(q, 3) = getArg(r, 0);
199 1 : pushInstruction(mb, q);
200 1 : freeInstruction(p);
201 1 : done = 1;
202 1 : break;
203 : #if 0
204 : } else if (getFunctionId(p) == selectRef && p->argc == 9) {
205 : /* select (c, s, l, h, li, hi, anti, unknown ) */
206 : InstrPtr r = newInstructionArgs(mb, dictRef, selectRef, 10);
207 : if (r == NULL) {
208 : msg = createException(MAL, "optimizer.for",
209 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
210 : break;
211 : }
212 :
213 : getArg(r, 0) = getArg(p, 0);
214 : r = pushArgument(mb, r, varisdict[k]);
215 : r = pushArgument(mb, r, getArg(p, 2)); /* cand */
216 : r = pushArgument(mb, r, vardictvalue[k]);
217 : r = pushArgument(mb, r, getArg(p, 3)); /* l */
218 : r = pushArgument(mb, r, getArg(p, 4)); /* h */
219 : r = pushArgument(mb, r, getArg(p, 5)); /* li */
220 : r = pushArgument(mb, r, getArg(p, 6)); /* hi */
221 : r = pushArgument(mb, r, getArg(p, 7)); /* anti */
222 : r = pushArgument(mb, r, getArg(p, 8)); /* unknown */
223 : pushInstruction(mb, r);
224 : freeInstruction(p);
225 : done = 1;
226 : break;
227 : } else if (isSelect(p)) {
228 : /* pos = select(col, cand, l, h, ...) with col = dict.decompress(o,u)
229 : * tp = select(u, nil, l, h, ...)
230 : * tp2 = batcalc.bte/sht/int(tp)
231 : * pos = intersect(o, tp2, cand, nil) */
232 :
233 : int cand = getArg(p, j + 1);
234 : InstrPtr r = copyInstruction(p);
235 : if (r == NULL) {
236 : msg = createException(MAL, "optimizer.for",
237 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
238 : break;
239 : }
240 : getArg(r, j) = vardictvalue[k];
241 : if (cand)
242 : r = ReplaceWithNil(mb, r, j + 1); /* no candidate list */
243 : pushInstruction(mb, r);
244 :
245 : int tpe = getVarType(mb, varisdict[k]);
246 : InstrPtr s = newInstructionArgs(mb, dictRef, putName("convert"), 3);
247 : if (s == NULL) {
248 : msg = createException(MAL, "optimizer.for",
249 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
250 : break;
251 : }
252 : getArg(s, 0) = newTmpVariable(mb, tpe);
253 : s = pushArgument(mb, s, getArg(r, 0));
254 : pushInstruction(mb, s);
255 :
256 : InstrPtr t = newInstructionArgs(mb, algebraRef, intersectRef, 9);
257 : if (t == NULL) {
258 : msg = createException(MAL, "optimizer.for",
259 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
260 : break;
261 : }
262 : getArg(t, 0) = getArg(p, 0);
263 : t = pushArgument(mb, t, varisdict[k]);
264 : t = pushArgument(mb, t, getArg(s, 0));
265 : t = pushArgument(mb, t, cand);
266 : t = pushNilBat(mb, t);
267 : t = pushBit(mb, t, TRUE); /* nil matches */
268 : t = pushBit(mb, t, TRUE); /* max_one */
269 : t = pushNil(mb, t, TYPE_lng); /* estimate */
270 : pushInstruction(mb, t);
271 : freeInstruction(p);
272 : done = 1;
273 : break;
274 : #endif
275 12 : } else if ((isMapOp(p) || isMap2Op(p))
276 2 : && (getFunctionId(p) == plusRef
277 2 : || getFunctionId(p) == minusRef) && p->argc > 2
278 1 : && getBatType(getArgType(mb, p, 2)) != TYPE_oid
279 1 : && allConstExcept(mb, p, j)) {
280 : /* filter out unary batcalc.- with and without a candidate list */
281 : /* batcalc.-(1, col) with col = for.decompress(o,min_val)
282 : * v1 = calc.-(1, min_val)
283 : * for.decompress(o, v1) */
284 : /* we assume binary operators only ! */
285 0 : InstrPtr r = newInstructionArgs(mb, calcRef, getFunctionId(p), 3);
286 0 : if (r == NULL) {
287 0 : msg = createException(MAL, "optimizer.for",
288 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
289 0 : break;
290 : }
291 0 : int tpe = getBatType(getVarType(mb, getArg(p, 0)));
292 0 : getArg(r, 0) = newTmpVariable(mb, tpe);
293 0 : int l = getArg(r, 0), m = getArg(p, 0);
294 0 : r = pushArgument(mb, r, getArg(p, 1));
295 0 : r = pushArgument(mb, r, getArg(p, 2));
296 0 : getArg(r, j) = varforvalue[k];
297 :
298 : /* new and old result are now min-values */
299 0 : varisfor[l] = varisfor[m] = varisfor[k];
300 0 : varforvalue[l] = varforvalue[m] = getArg(r, 0);
301 0 : pushInstruction(mb, r);
302 0 : freeInstruction(p);
303 0 : done = 1;
304 0 : break;
305 12 : } else if (getModuleId(p) == groupRef
306 1 : && (getFunctionId(p) == subgroupRef
307 1 : || getFunctionId(p) == subgroupdoneRef
308 1 : || getFunctionId(p) == groupRef
309 1 : || getFunctionId(p) == groupdoneRef)) {
310 : /* group.group[done](col) | group.subgroup[done](col, grp) with col = for.decompress(o,min_val)
311 : * v1 = group.group[done](o) | group.subgroup[done](o, grp) */
312 1 : int input = varisfor[k];
313 1 : InstrPtr r = copyInstruction(p);
314 1 : if (r == NULL) {
315 0 : msg = createException(MAL, "optimizer.for",
316 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
317 0 : break;
318 : }
319 1 : getArg(r, j) = input;
320 1 : pushInstruction(mb, r);
321 1 : freeInstruction(p);
322 1 : done = 1;
323 1 : break;
324 : } else {
325 : /* need to decompress */
326 11 : int tpe = getArgType(mb, p, j);
327 11 : InstrPtr r = newInstructionArgs(mb, forRef, decompressRef, 3);
328 11 : if (r == NULL) {
329 0 : msg = createException(MAL, "optimizer.for",
330 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
331 0 : break;
332 : }
333 11 : getArg(r, 0) = newTmpVariable(mb, tpe);
334 11 : r = pushArgument(mb, r, varisfor[k]);
335 11 : r = pushArgument(mb, r, varforvalue[k]);
336 11 : pushInstruction(mb, r);
337 :
338 11 : getArg(p, j) = getArg(r, 0);
339 11 : actions++;
340 : }
341 : }
342 : }
343 0 : if (msg)
344 : break;
345 488 : if (done)
346 16 : actions++;
347 : else
348 472 : pushInstruction(mb, p);
349 : }
350 :
351 1303 : for (; i < slimit; i++)
352 1296 : if (old[i])
353 0 : freeInstruction(old[i]);
354 : /* Defense line against incorrect plans */
355 7 : if (msg == MAL_SUCCEED && actions > 0) {
356 7 : msg = chkTypes(cntxt->usermodule, mb, FALSE);
357 7 : if (!msg)
358 7 : msg = chkFlow(mb);
359 7 : if (!msg)
360 7 : msg = chkDeclarations(mb);
361 : }
362 : /* keep all actions taken as a post block comment */
363 0 : wrapup:
364 : /* keep actions taken as a fake argument */
365 568947 : (void) pushInt(mb, pci, actions);
366 :
367 568994 : GDKfree(old);
368 568977 : GDKfree(varisfor);
369 569042 : GDKfree(varforvalue);
370 569042 : return msg;
371 : }
|