Line data Source code
1 : /*
2 : * SPDX-License-Identifier: MPL-2.0
3 : *
4 : * This Source Code Form is subject to the terms of the Mozilla Public
5 : * License, v. 2.0. If a copy of the MPL was not distributed with this
6 : * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 : *
8 : * Copyright 2024 MonetDB Foundation;
9 : * Copyright August 2008 - 2023 MonetDB B.V.;
10 : * Copyright 1997 - July 2008 CWI.
11 : */
12 :
13 : #include "monetdb_config.h"
14 : #include "opt_dict.h"
15 :
16 : static inline InstrPtr
17 2 : ReplaceWithNil(MalBlkPtr mb, InstrPtr p, int pos)
18 : {
19 2 : p = pushNilBat(mb, p); /* push at end */
20 2 : getArg(p, pos) = getArg(p, p->argc - 1);
21 2 : p->argc--;
22 2 : return p;
23 : }
24 :
25 : static inline bool
26 59 : allConstExcept(MalBlkPtr mb, InstrPtr p, int except)
27 : {
28 108 : for (int j = p->retc; j < p->argc; j++) {
29 93 : if (j != except && getArgType(mb, p, j) >= TYPE_any)
30 : return false;
31 : }
32 : return true;
33 : }
34 :
35 : str
36 568039 : OPTdictImplementation(Client cntxt, MalBlkPtr mb, MalStkPtr stk, InstrPtr pci)
37 : {
38 568039 : int i, j, k, limit, slimit;
39 568039 : InstrPtr p = NULL, *old = NULL;
40 568039 : int actions = 0;
41 568039 : int *varisdict = NULL, *vardictvalue = NULL;
42 568039 : bit *dictunique = NULL;
43 568039 : str msg = MAL_SUCCEED;
44 :
45 568039 : (void) cntxt;
46 568039 : (void) stk; /* to fool compilers */
47 :
48 568039 : if (mb->inlineProp)
49 0 : goto wrapup;
50 :
51 568039 : varisdict = GDKzalloc(2 * mb->vtop * sizeof(int));
52 569171 : vardictvalue = GDKzalloc(2 * mb->vtop * sizeof(int));
53 569191 : dictunique = GDKzalloc(2 * mb->vtop * sizeof(bit));
54 569085 : if (varisdict == NULL || vardictvalue == NULL || dictunique == NULL)
55 0 : goto wrapup;
56 :
57 569085 : limit = mb->stop;
58 569085 : slimit = mb->ssize;
59 569085 : old = mb->stmt;
60 569085 : if (newMalBlkStmt(mb, mb->ssize) < 0) {
61 0 : GDKfree(varisdict);
62 0 : GDKfree(vardictvalue);
63 0 : GDKfree(dictunique);
64 0 : throw(MAL, "optimizer.dict", SQLSTATE(HY013) MAL_MALLOC_FAIL);
65 : }
66 : /* Consolidate the actual need for variables */
67 22597579 : for (i = 0; mb->errors == NULL && i < limit; i++) {
68 22028438 : p = old[i];
69 22028438 : if (p == NULL)
70 0 : continue; /* left behind by others? */
71 22028438 : if (p->retc == 1 && getModuleId(p) == dictRef
72 222 : && getFunctionId(p) == decompressRef) {
73 : /* remember we have encountered a dict decompress function */
74 168 : k = getArg(p, 0);
75 168 : varisdict[k] = getArg(p, 1);
76 168 : vardictvalue[k] = getArg(p, 2);
77 168 : dictunique[k] = 1;
78 168 : freeInstruction(p);
79 168 : old[i] = NULL;
80 168 : continue;
81 : }
82 55820017 : bool done = false;
83 55820017 : for (j = p->retc; j < p->argc; j++) {
84 33791477 : k = getArg(p, j);
85 33791477 : if (varisdict[k]) { /* maybe we could delay this usage */
86 0 : if (getModuleId(p) == algebraRef
87 526 : && getFunctionId(p) == projectionRef) {
88 : /* projection(cand, col) with col = dict.decompress(o,u)
89 : * v1 = projection(cand, o)
90 : * dict.decompress(v1, u) */
91 352 : InstrPtr r = copyInstruction(p);
92 352 : if (r == NULL) {
93 0 : msg = createException(MAL, "optimizer.dict",
94 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
95 0 : break;
96 : }
97 352 : int tpe = getVarType(mb, varisdict[k]);
98 352 : int l = getArg(r, 0);
99 352 : getArg(r, 0) = newTmpVariable(mb, tpe);
100 352 : getArg(r, j) = varisdict[k];
101 352 : varisdict[l] = getArg(r, 0);
102 352 : vardictvalue[l] = vardictvalue[k];
103 352 : dictunique[l] = dictunique[k];
104 352 : pushInstruction(mb, r);
105 352 : freeInstruction(p);
106 352 : old[i] = NULL;
107 352 : done = true;
108 352 : break;
109 0 : } else if (p->argc == 2 && p->retc == 1
110 16 : && p->barrier == ASSIGNsymbol) {
111 : /* a = b */
112 0 : int l = getArg(p, 0);
113 0 : varisdict[l] = varisdict[k];
114 0 : vardictvalue[l] = vardictvalue[k];
115 0 : dictunique[l] = dictunique[k];
116 0 : freeInstruction(p);
117 0 : old[i] = NULL;
118 0 : done = true;
119 0 : break;
120 0 : } else if (getModuleId(p) == algebraRef
121 174 : && getFunctionId(p) == subsliceRef) {
122 : /* pos = subslice(col, l, h) with col = dict.decompress(o,u)
123 : * pos = subslice(o, l, h) */
124 0 : InstrPtr r = copyInstruction(p);
125 0 : if (r == NULL) {
126 0 : msg = createException(MAL, "optimizer.dict",
127 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
128 0 : break;
129 : }
130 0 : getArg(r, j) = varisdict[k];
131 0 : pushInstruction(mb, r);
132 0 : freeInstruction(p);
133 0 : old[i] = NULL;
134 0 : done = true;
135 0 : break;
136 0 : } else if ((getModuleId(p) == batRef
137 10 : && getFunctionId(p) == mirrorRef)
138 0 : || (getModuleId(p) == batcalcRef
139 50 : && getFunctionId(p) == identityRef)) {
140 : /* id = mirror/identity(col) with col = dict.decompress(o,u)
141 : * id = mirror/identity(o) */
142 8 : InstrPtr r = copyInstruction(p);
143 8 : if (r == NULL) {
144 0 : msg = createException(MAL, "optimizer.dict",
145 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
146 0 : break;
147 : }
148 8 : getArg(r, j) = varisdict[k];
149 8 : pushInstruction(mb, r);
150 8 : freeInstruction(p);
151 8 : old[i] = NULL;
152 8 : done = true;
153 8 : break;
154 0 : } else if (isSelect(p)) {
155 110 : if (getFunctionId(p) == thetaselectRef) {
156 87 : InstrPtr r = newInstructionArgs(mb, dictRef, thetaselectRef, 6);
157 87 : if (r == NULL) {
158 0 : msg = createException(MAL, "optimizer.dict",
159 : SQLSTATE(HY013)
160 : MAL_MALLOC_FAIL);
161 0 : break;
162 : }
163 :
164 87 : getArg(r, 0) = getArg(p, 0);
165 87 : r = pushArgument(mb, r, varisdict[k]);
166 87 : r = pushArgument(mb, r, getArg(p, 2)); /* cand */
167 87 : r = pushArgument(mb, r, vardictvalue[k]);
168 87 : r = pushArgument(mb, r, getArg(p, 3)); /* val */
169 87 : r = pushArgument(mb, r, getArg(p, 4)); /* op */
170 87 : pushInstruction(mb, r);
171 44 : } else if (getFunctionId(p) == selectRef && p->argc == 9) {
172 : /* select (c, s, l, h, li, hi, anti, unknown ) */
173 21 : InstrPtr r = newInstructionArgs(mb, dictRef, selectRef, 10);
174 21 : if (r == NULL) {
175 0 : msg = createException(MAL, "optimizer.dict",
176 : SQLSTATE(HY013)
177 : MAL_MALLOC_FAIL);
178 0 : break;
179 : }
180 :
181 21 : getArg(r, 0) = getArg(p, 0);
182 21 : r = pushArgument(mb, r, varisdict[k]);
183 21 : r = pushArgument(mb, r, getArg(p, 2)); /* cand */
184 21 : r = pushArgument(mb, r, vardictvalue[k]);
185 21 : r = pushArgument(mb, r, getArg(p, 3)); /* l */
186 21 : r = pushArgument(mb, r, getArg(p, 4)); /* h */
187 21 : r = pushArgument(mb, r, getArg(p, 5)); /* li */
188 21 : r = pushArgument(mb, r, getArg(p, 6)); /* hi */
189 21 : r = pushArgument(mb, r, getArg(p, 7)); /* anti */
190 21 : r = pushArgument(mb, r, getArg(p, 8)); /* unknown */
191 21 : pushInstruction(mb, r);
192 : } else {
193 : /* pos = select(col, cand, l, h, ...) with col = dict.decompress(o,u)
194 : * tp = select(u, nil, l, h, ...)
195 : * tp2 = batcalc.bte/sht/int(tp)
196 : * pos = intersect(o, tp2, cand, nil) */
197 :
198 2 : int has_cand = getArgType(mb, p, 2) == newBatType(TYPE_oid);
199 2 : InstrPtr r = copyInstruction(p);
200 2 : InstrPtr s = newInstructionArgs(mb, dictRef, putName("convert"), 3);
201 2 : InstrPtr t = newInstructionArgs(mb, algebraRef, intersectRef, 9);
202 2 : if (r == NULL || s == NULL || t == NULL) {
203 0 : freeInstruction(r);
204 0 : freeInstruction(s);
205 0 : freeInstruction(t);
206 0 : msg = createException(MAL, "optimizer.dict",
207 : SQLSTATE(HY013)
208 : MAL_MALLOC_FAIL);
209 0 : break;
210 : }
211 :
212 2 : getArg(r, 0) = newTmpVariable(mb, newBatType(TYPE_oid));
213 2 : getArg(r, j) = vardictvalue[k];
214 2 : if (has_cand)
215 2 : r = ReplaceWithNil(mb, r, 2); /* no candidate list */
216 2 : pushInstruction(mb, r);
217 :
218 2 : int tpe = getVarType(mb, varisdict[k]);
219 2 : getArg(s, 0) = newTmpVariable(mb, tpe);
220 2 : s = pushArgument(mb, s, getArg(r, 0));
221 2 : pushInstruction(mb, s);
222 :
223 2 : getArg(t, 0) = getArg(p, 0);
224 2 : t = pushArgument(mb, t, varisdict[k]);
225 2 : t = pushArgument(mb, t, getArg(s, 0));
226 2 : if (has_cand)
227 2 : t = pushArgument(mb, t, getArg(p, 2));
228 : else
229 0 : t = pushNilBat(mb, t);
230 2 : t = pushNilBat(mb, t);
231 2 : t = pushBit(mb, t, TRUE); /* nil matches */
232 2 : t = pushBit(mb, t, TRUE); /* max_one */
233 2 : t = pushNil(mb, t, TYPE_lng); /* estimate */
234 2 : pushInstruction(mb, t);
235 : }
236 110 : freeInstruction(p);
237 110 : old[i] = NULL;
238 110 : done = true;
239 110 : break;
240 202 : } else if (j == 2 && p->argc > j + 1
241 33 : && getModuleId(p) == algebraRef
242 13 : && getFunctionId(p) == joinRef
243 7 : && varisdict[getArg(p, j + 1)]
244 4 : && vardictvalue[k] == vardictvalue[getArg(p, j + 1)]) {
245 : /* (r1, r2) = join(col1, col2, cand1, cand2, ...) with
246 : * col1 = dict.decompress(o1,u1), col2 = dict.decompress(o2,u2)
247 : * iff u1 == u2
248 : * (r1, r2) = algebra.join(o1, o2, cand1, cand2, ...) */
249 0 : int l = getArg(p, j + 1);
250 0 : InstrPtr r = copyInstruction(p);
251 0 : if (r == NULL) {
252 0 : msg = createException(MAL, "optimizer.dict",
253 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
254 0 : break;
255 : }
256 0 : getArg(r, j + 0) = varisdict[k];
257 0 : getArg(r, j + 1) = varisdict[l];
258 0 : pushInstruction(mb, r);
259 0 : freeInstruction(p);
260 0 : old[i] = NULL;
261 0 : done = true;
262 0 : break;
263 42 : } else if (j == 2 && p->argc > j + 1
264 33 : && getModuleId(p) == algebraRef
265 13 : && getFunctionId(p) == joinRef
266 7 : && varisdict[getArg(p, j + 1)]
267 4 : && vardictvalue[k] != vardictvalue[getArg(p, j + 1)]) {
268 : /* (r1, r2) = join(col1, col2, cand1, cand2, ...) with
269 : * col1 = dict.decompress(o1,u1), col2 = dict.decompress(o2,u2)
270 : * (r1, r2) = dict.join(o1, u1, o2, u2, cand1, cand2, ...) */
271 4 : int l = getArg(p, j + 1);
272 4 : InstrPtr r = newInstructionArgs(mb, dictRef, joinRef, 10);
273 4 : if (r == NULL) {
274 0 : msg = createException(MAL, "optimizer.dict",
275 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
276 0 : break;
277 : }
278 4 : assert(p->argc == 8);
279 4 : getArg(r, 0) = getArg(p, 0);
280 4 : r = pushReturn(mb, r, getArg(p, 1));
281 4 : r = pushArgument(mb, r, varisdict[k]);
282 4 : r = pushArgument(mb, r, vardictvalue[k]);
283 4 : r = pushArgument(mb, r, varisdict[l]);
284 4 : r = pushArgument(mb, r, vardictvalue[l]);
285 4 : r = pushArgument(mb, r, getArg(p, 4));
286 4 : r = pushArgument(mb, r, getArg(p, 5));
287 4 : r = pushArgument(mb, r, getArg(p, 6));
288 4 : r = pushArgument(mb, r, getArg(p, 7));
289 4 : pushInstruction(mb, r);
290 4 : freeInstruction(p);
291 4 : old[i] = NULL;
292 4 : done = true;
293 4 : break;
294 198 : } else if ((isMapOp(p) || isMap2Op(p))
295 59 : && allConstExcept(mb, p, j)) {
296 : /* batcalc.-(1, col) with col = dict.decompress(o,u)
297 : * v1 = batcalc.-(1, u)
298 : * dict.decompress(o, v1) */
299 15 : InstrPtr r = copyInstruction(p);
300 15 : if (r == NULL) {
301 0 : msg = createException(MAL, "optimizer.dict",
302 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
303 0 : break;
304 : }
305 15 : int tpe = getVarType(mb, getArg(p, 0));
306 15 : int l = getArg(r, 0), m = getArg(p, 0);
307 15 : getArg(r, 0) = newTmpVariable(mb, tpe);
308 15 : getArg(r, j) = vardictvalue[k];
309 :
310 : /* new and old result are now dicts */
311 15 : varisdict[l] = varisdict[m] = varisdict[k];
312 15 : vardictvalue[l] = vardictvalue[m] = getArg(r, 0);
313 15 : dictunique[l] = 0;
314 15 : pushInstruction(mb, r);
315 15 : freeInstruction(p);
316 15 : old[i] = NULL;
317 15 : done = true;
318 15 : break;
319 183 : } else if (getModuleId(p) == groupRef
320 20 : && (getFunctionId(p) == subgroupRef
321 16 : || getFunctionId(p) == subgroupdoneRef
322 11 : || getFunctionId(p) == groupRef
323 8 : || getFunctionId(p) == groupdoneRef)) {
324 : /* group.group[done](col) | group.subgroup[done](col, grp) with col = dict.decompress(o,u)
325 : * v1 = group.group[done](o) | group.subgroup[done](o, grp) */
326 20 : int input = varisdict[k];
327 20 : if (!dictunique[k]) {
328 : /* make new dict and renumber the inputs */
329 :
330 7 : int tpe = getVarType(mb, varisdict[k]);
331 : /*(o,v) = compress(vardictvalue[k]); */
332 7 : InstrPtr r = newInstructionArgs(mb, dictRef, compressRef, 3);
333 7 : InstrPtr s = newInstructionArgs(mb, dictRef, renumberRef, 3);
334 7 : if (r == NULL || s == NULL) {
335 0 : freeInstruction(r);
336 0 : freeInstruction(s);
337 0 : msg = createException(MAL, "optimizer.dict",
338 : SQLSTATE(HY013)
339 : MAL_MALLOC_FAIL);
340 0 : break;
341 : }
342 : /* dynamic type problem ie could be bte or sht, use same type as input dict */
343 7 : getArg(r, 0) = newTmpVariable(mb, tpe);
344 7 : r = pushReturn(mb, r,
345 : newTmpVariable(mb,
346 7 : getArgType(mb, p, j)));
347 7 : r = pushArgument(mb, r, vardictvalue[k]);
348 7 : pushInstruction(mb, r);
349 :
350 : /* newvar = renumber(varisdict[k], o); */
351 7 : getArg(s, 0) = newTmpVariable(mb, tpe);
352 7 : s = pushArgument(mb, s, varisdict[k]);
353 7 : s = pushArgument(mb, s, getArg(r, 0));
354 7 : pushInstruction(mb, s);
355 :
356 7 : input = getArg(s, 0);
357 : }
358 20 : InstrPtr r = copyInstruction(p);
359 20 : if (r == NULL) {
360 0 : msg = createException(MAL, "optimizer.dict",
361 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
362 0 : break;
363 : }
364 20 : getArg(r, j) = input;
365 20 : pushInstruction(mb, r);
366 20 : freeInstruction(p);
367 20 : old[i] = NULL;
368 20 : done = true;
369 20 : break;
370 : } else {
371 : /* need to decompress */
372 163 : int tpe = getArgType(mb, p, j);
373 163 : InstrPtr r = newInstructionArgs(mb, dictRef, decompressRef, 3);
374 163 : if (r == NULL) {
375 0 : msg = createException(MAL, "optimizer.dict",
376 : SQLSTATE(HY013) MAL_MALLOC_FAIL);
377 0 : break;
378 : }
379 163 : getArg(r, 0) = newTmpVariable(mb, tpe);
380 163 : r = pushArgument(mb, r, varisdict[k]);
381 163 : r = pushArgument(mb, r, vardictvalue[k]);
382 163 : pushInstruction(mb, r);
383 :
384 163 : getArg(p, j) = getArg(r, 0);
385 163 : actions++;
386 : }
387 : }
388 : }
389 0 : if (msg)
390 : break;
391 22029049 : if (done)
392 0 : actions++;
393 : else {
394 22030377 : pushInstruction(mb, p);
395 22029545 : old[i] = NULL;
396 : }
397 : }
398 :
399 125984835 : for (; i < slimit; i++)
400 125415679 : if (old[i])
401 0 : freeInstruction(old[i]);
402 : /* Defense line against incorrect plans */
403 569156 : if (msg == MAL_SUCCEED && actions > 0) {
404 82 : msg = chkTypes(cntxt->usermodule, mb, FALSE);
405 82 : if (!msg)
406 82 : msg = chkFlow(mb);
407 82 : if (!msg)
408 82 : msg = chkDeclarations(mb);
409 : }
410 : /* keep all actions taken as a post block comment */
411 569074 : wrapup:
412 : /* keep actions taken as a fake argument */
413 569156 : (void) pushInt(mb, pci, actions);
414 :
415 569122 : GDKfree(old);
416 569207 : GDKfree(varisdict);
417 569174 : GDKfree(vardictvalue);
418 569206 : GDKfree(dictunique);
419 569206 : return msg;
420 : }
|