Line data Source code
1 : /*
2 : * SPDX-License-Identifier: MPL-2.0
3 : *
4 : * This Source Code Form is subject to the terms of the Mozilla Public
5 : * License, v. 2.0. If a copy of the MPL was not distributed with this
6 : * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 : *
8 : * Copyright 2024 MonetDB Foundation;
9 : * Copyright August 2008 - 2023 MonetDB B.V.;
10 : * Copyright 1997 - July 2008 CWI.
11 : */
12 :
13 : /* Generic stream handling code such as init and close */
14 :
15 : #include "monetdb_config.h"
16 : #include "stream.h"
17 : #include "stream_internal.h"
18 : #ifdef HAVE_SYS_TIME_H
19 : #include <sys/time.h>
20 : #endif
21 : #ifdef HAVE_SYS_IOCTL_H
22 : #include <sys/ioctl.h>
23 : #endif
24 :
25 :
26 : /* ------------------------------------------------------------------ */
27 : /* streams working on a socket */
28 :
29 : static int
30 943940 : socket_getoob(const stream *s)
31 : {
32 943940 : SOCKET fd = s->stream_data.s;
33 : #ifdef HAVE_POLL
34 943940 : struct pollfd pfd = (struct pollfd) {
35 : .fd = fd,
36 : .events = POLLPRI,
37 : };
38 943940 : if (poll(&pfd, 1, 0) > 0)
39 : #else
40 : fd_set xfds;
41 : struct timeval t = (struct timeval) {
42 : .tv_sec = 0,
43 : .tv_usec = 0,
44 : };
45 : #ifndef _MSC_VER
46 : #ifdef FD_SETSIZE
47 : if (fd >= FD_SETSIZE)
48 : return 0;
49 : #endif
50 : #endif
51 : FD_ZERO(&xfds);
52 : FD_SET(fd, &xfds);
53 : if (select(
54 : #ifdef _MSC_VER
55 : 0, /* ignored on Windows */
56 : #else
57 : fd + 1,
58 : #endif
59 : NULL, NULL, &xfds, &t) > 0)
60 : #endif
61 : {
62 : #ifdef HAVE_POLL
63 100 : if (pfd.revents & (POLLHUP | POLLNVAL))
64 100 : return -1;
65 5 : if ((pfd.revents & POLLPRI) == 0)
66 : return -1;
67 : #else
68 : if (!FD_ISSET(fd, &xfds))
69 : return 0;
70 : #endif
71 : /* discard regular data until OOB mark */
72 : #ifndef _MSC_VER /* Windows has to be different... */
73 0 : for (;;) {
74 0 : int atmark = 0;
75 0 : char flush[100];
76 : #ifdef HAVE_SOCKATMARK
77 0 : if ((atmark = sockatmark(fd)) < 0) {
78 0 : perror("sockatmark");
79 0 : break;
80 : }
81 : #else
82 : if (ioctlsocket(fd, SIOCATMARK, &atmark) < 0) {
83 : perror("ioctl");
84 : break;
85 : }
86 : #endif
87 0 : if (atmark)
88 : break;
89 0 : if (recv(fd, flush, sizeof(flush), 0) < 0) {
90 0 : perror("recv");
91 0 : break;
92 : }
93 : }
94 : #endif
95 0 : char b = 0;
96 0 : switch (recv(fd, &b, 1, MSG_OOB)) {
97 : case 0:
98 : /* unexpectedly didn't receive a byte */
99 : break;
100 0 : case 1:
101 0 : return b;
102 0 : case -1:
103 0 : perror("recv OOB");
104 0 : return -1;
105 : }
106 : }
107 : return 0;
108 : }
109 :
110 : static int
111 0 : socket_putoob(const stream *s, char val)
112 : {
113 0 : SOCKET fd = s->stream_data.s;
114 0 : if (send(fd, &val, 1, MSG_OOB) == -1) {
115 0 : perror("send OOB");
116 0 : return -1;
117 : }
118 : return 0;
119 : }
120 :
121 : #ifdef HAVE_SYS_UN_H
122 : /* UNIX domain sockets do not support OOB messages, so we need to do
123 : * something different */
124 : #define OOBMSG0 '\377' /* the two special bytes we send as "OOB" */
125 : #define OOBMSG1 '\377'
126 :
127 : static int
128 8881786 : socket_getoob_unix(const stream *s)
129 : {
130 8881786 : SOCKET fd = s->stream_data.s;
131 : #ifdef HAVE_POLL
132 8881786 : struct pollfd pfd = (struct pollfd) {
133 : .fd = fd,
134 : .events = POLLIN,
135 : };
136 8881786 : if (poll(&pfd, 1, 0) > 0)
137 : #else
138 : fd_set fds;
139 : struct timeval t = (struct timeval) {
140 : .tv_sec = 0,
141 : .tv_usec = 0,
142 : };
143 : #ifndef _MSC_VER
144 : #ifdef FD_SETSIZE
145 : if (fd >= FD_SETSIZE)
146 : return 0;
147 : #endif
148 : #endif
149 : FD_ZERO(&fds);
150 : FD_SET(fd, &fds);
151 : if (select(
152 : #ifdef _MSC_VER
153 : 0, /* ignored on Windows */
154 : #else
155 : fd + 1,
156 : #endif
157 : &fds, NULL, NULL, &t) > 0)
158 : #endif
159 : {
160 7796 : char buf[3];
161 7796 : ssize_t nr;
162 7796 : nr = recv(fd, buf, 2, MSG_PEEK);
163 7796 : if (nr == 2 && buf[0] == OOBMSG0 && buf[1] == OOBMSG1) {
164 0 : nr = recv(fd, buf, 3, 0);
165 0 : if (nr == 3)
166 0 : return buf[2];
167 : }
168 : }
169 : return 0;
170 : }
171 :
172 : static int
173 0 : socket_putoob_unix(const stream *s, char val)
174 : {
175 0 : char buf[3] = {
176 : OOBMSG0,
177 : OOBMSG1,
178 : val,
179 : };
180 0 : if (send(s->stream_data.s, buf, 3, 0) == -1) {
181 0 : perror("send");
182 0 : return -1;
183 : }
184 : return 0;
185 : }
186 : #endif
187 :
188 : static ssize_t
189 1958746 : socket_write(stream *restrict s, const void *restrict buf, size_t elmsize, size_t cnt)
190 : {
191 1958746 : size_t size = elmsize * cnt, res = 0;
192 : #ifdef _MSC_VER
193 : int nr = 0;
194 : #else
195 1958746 : ssize_t nr = 0;
196 : #endif
197 :
198 1958746 : if (s->errkind != MNSTR_NO__ERROR)
199 : return -1;
200 :
201 1958746 : if (size == 0 || elmsize == 0)
202 0 : return (ssize_t) cnt;
203 :
204 1958746 : errno = 0;
205 3917328 : while (res < size &&
206 : (
207 : /* Windows send works on int, make sure the argument fits */
208 1958755 : ((nr = send(s->stream_data.s, (const char *) buf + res,
209 : #ifdef _MSC_VER
210 : (int) min(size - res, 1 << 16)
211 : #else
212 : size
213 : #endif
214 : , 0)) > 0)
215 162 : || (nr < 0 && /* syscall failed */
216 162 : s->timeout > 0 && /* potentially timeout */
217 : #ifdef _MSC_VER
218 : WSAGetLastError() == WSAEWOULDBLOCK &&
219 : #else
220 0 : (errno == EAGAIN
221 : #if EAGAIN != EWOULDBLOCK
222 : || errno == EWOULDBLOCK
223 : #endif
224 0 : ) && /* it was! */
225 : #endif
226 0 : s->timeout_func != NULL && /* callback function exists */
227 0 : !s->timeout_func(s->timeout_data)) /* callback says don't stop */
228 162 : || (nr < 0 &&
229 : #ifdef _MSC_VER
230 : WSAGetLastError() == WSAEINTR
231 : #else
232 162 : errno == EINTR
233 : #endif
234 : )) /* interrupted */
235 : ) {
236 1958582 : errno = 0;
237 : #ifdef _MSC_VER
238 : WSASetLastError(0);
239 : #endif
240 1958582 : if (nr > 0)
241 1958582 : res += (size_t) nr;
242 : }
243 1958735 : if (res >= elmsize)
244 1958573 : return (ssize_t) (res / elmsize);
245 162 : if (nr < 0) {
246 162 : if (s->timeout > 0 &&
247 : #ifdef _MSC_VER
248 : WSAGetLastError() == WSAEWOULDBLOCK
249 : #else
250 0 : (errno == EAGAIN
251 : #if EAGAIN != EWOULDBLOCK
252 : || errno == EWOULDBLOCK
253 : #endif
254 : )
255 : #endif
256 : )
257 0 : mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
258 : else
259 162 : mnstr_set_error_errno(s, MNSTR_WRITE_ERROR, "socket write");
260 162 : return -1;
261 : }
262 : return 0;
263 : }
264 :
265 : static ssize_t
266 7322260 : socket_read(stream *restrict s, void *restrict buf, size_t elmsize, size_t cnt)
267 : {
268 : #ifdef _MSC_VER
269 : int nr = 0;
270 : int size;
271 : if (elmsize * cnt > INT_MAX)
272 : size = (int) (elmsize * (INT_MAX / elmsize));
273 : else
274 : size = (int) (elmsize * cnt);
275 : #else
276 7322260 : ssize_t nr = 0;
277 7322260 : size_t size = elmsize * cnt;
278 : #endif
279 :
280 7322260 : if (s->errkind != MNSTR_NO__ERROR)
281 : return -1;
282 7322260 : if (size == 0)
283 : return 0;
284 :
285 7331312 : for (;;) {
286 7326786 : if (s->timeout) {
287 6297081 : int ret;
288 : #ifdef HAVE_POLL
289 6297081 : struct pollfd pfd;
290 :
291 6297081 : pfd = (struct pollfd) {.fd = s->stream_data.s,
292 : .events = POLLIN};
293 : #ifdef HAVE_SYS_UN_H
294 6297081 : if (s->putoob != socket_putoob_unix)
295 5559897 : pfd.events |= POLLPRI;
296 : #endif
297 :
298 6297081 : ret = poll(&pfd, 1, (int) s->timeout);
299 6297086 : if (ret == -1) {
300 0 : if (errno == EINTR)
301 4526 : continue;
302 0 : mnstr_set_error_errno(s, MNSTR_READ_ERROR, "poll error");
303 35609 : return -1;
304 : }
305 6297086 : if (ret == 1) {
306 6292542 : if (pfd.revents & POLLHUP) {
307 : /* hung up, return EOF */
308 35591 : s->eof = true;
309 35591 : return 0;
310 : }
311 6256951 : if (pfd.revents & POLLPRI) {
312 : /* discard regular data until OOB mark */
313 0 : for (;;) {
314 0 : int atmark = 0;
315 0 : char flush[100];
316 : #ifdef HAVE_SOCKATMARK
317 0 : if ((atmark = sockatmark(s->stream_data.s)) < 0) {
318 0 : perror("sockatmark");
319 0 : break;
320 : }
321 : #else
322 : if (ioctlsocket(s->stream_data.s, SIOCATMARK, &atmark) < 0) {
323 : perror("ioctl");
324 : break;
325 : }
326 : #endif
327 0 : if (atmark)
328 : break;
329 0 : if (recv(s->stream_data.s, flush, sizeof(flush), 0) < 0) {
330 0 : perror("recv");
331 0 : break;
332 : }
333 : }
334 0 : char b = 0;
335 0 : switch (recv(s->stream_data.s, &b, 1, MSG_OOB)) {
336 0 : case 0:
337 : /* unexpectedly didn't receive a byte */
338 0 : continue;
339 0 : case 1:
340 0 : mnstr_set_error(s, MNSTR_INTERRUPT, "query abort from client");
341 0 : return -1;
342 0 : case -1:
343 0 : mnstr_set_error_errno(s, MNSTR_READ_ERROR, "recv error");
344 0 : return -1;
345 : }
346 : }
347 : }
348 : #else
349 : struct timeval tv;
350 : fd_set fds, xfds;
351 :
352 : errno = 0;
353 : #ifdef _MSC_VER
354 : WSASetLastError(0);
355 : #endif
356 : FD_ZERO(&fds);
357 : FD_SET(s->stream_data.s, &fds);
358 : FD_ZERO(&xfds);
359 : FD_SET(s->stream_data.s, &xfds);
360 : tv.tv_sec = s->timeout / 1000;
361 : tv.tv_usec = (s->timeout % 1000) * 1000;
362 : ret = select(
363 : #ifdef _MSC_VER
364 : 0, /* ignored on Windows */
365 : #else
366 : s->stream_data.s + 1,
367 : #endif
368 : &fds, NULL, &xfds, &tv);
369 : if (ret == SOCKET_ERROR) {
370 : mnstr_set_error_errno(s, MNSTR_READ_ERROR, "select");
371 : return -1;
372 : }
373 : if (ret > 0 && FD_ISSET(s->stream_data.s, &xfds)) {
374 : /* discard regular data until OOB mark */
375 : #ifndef _MSC_VER /* Windows has to be different... */
376 : for (;;) {
377 : int atmark = 0;
378 : char flush[100];
379 : #ifdef HAVE_SOCKATMARK
380 : if ((atmark = sockatmark(s->stream_data.s)) < 0) {
381 : perror("sockatmark");
382 : break;
383 : }
384 : #else
385 : if (ioctlsocket(s->stream_data.s, SIOCATMARK, &atmark) < 0) {
386 : perror("ioctl");
387 : break;
388 : }
389 : #endif
390 : if (atmark)
391 : break;
392 : if (recv(s->stream_data.s, flush, sizeof(flush), 0) < 0) {
393 : perror("recv");
394 : break;
395 : }
396 : }
397 : #endif
398 : char b = 0;
399 : switch (recv(s->stream_data.s, &b, 1, MSG_OOB)) {
400 : case 0:
401 : /* unexpectedly didn't receive a byte */
402 : continue;
403 : case 1:
404 : mnstr_set_error(s, MNSTR_INTERRUPT, "query abort from client");
405 : return -1;
406 : case -1:
407 : mnstr_set_error_errno(s, MNSTR_READ_ERROR, "recv error");
408 : return -1;
409 : }
410 : continue; /* try again */
411 : }
412 : #endif
413 6261495 : if (ret == 0) {
414 4544 : if (s->timeout_func == NULL || s->timeout_func(s->timeout_data)) {
415 18 : mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
416 18 : return -1;
417 : }
418 4526 : continue;
419 : }
420 : #ifdef HAVE_POLL
421 6256951 : assert(pfd.revents & (POLLIN|POLLHUP));
422 : #else
423 : assert(FD_ISSET(s->stream_data.s, &fds) || FD_ISSET(s->stream_data.s, &xfds));
424 : #endif
425 : }
426 7286656 : nr = recv(s->stream_data.s, buf, size, 0);
427 7286622 : if (nr == SOCKET_ERROR) {
428 1 : mnstr_set_error_errno(s, errno == EINTR ? MNSTR_INTERRUPT : MNSTR_READ_ERROR, NULL);
429 1 : return -1;
430 : }
431 : #ifdef HAVE_SYS_UN_H
432 : /* when reading a block size in a block stream
433 : * (elmsize==2,cnt==1), we may actually get an "OOB" message
434 : * when this is a Unix domain socket */
435 7286621 : if (s->putoob == socket_putoob_unix &&
436 1053970 : elmsize == 2 && cnt == 1 && nr == 2 &&
437 497356 : ((char *)buf)[0] == OOBMSG0 &&
438 438 : ((char *)buf)[1] == OOBMSG1) {
439 : /* also read (and discard) the "pay load" */
440 0 : (void) recv(s->stream_data.s, buf, 1, 0);
441 0 : mnstr_set_error(s, MNSTR_INTERRUPT, "query abort from client");
442 0 : return -1;
443 : }
444 : #endif
445 7286621 : break;
446 : }
447 7286621 : if (nr == 0) {
448 3148 : s->eof = true;
449 3148 : return 0; /* end of file */
450 : }
451 7283473 : if (elmsize > 1) {
452 1008933 : while ((size_t) nr % elmsize != 0) {
453 : /* if elmsize > 1, we really expect that "the
454 : * other side" wrote complete items in a
455 : * single system call, so we expect to at
456 : * least receive complete items, and hence we
457 : * continue reading until we did in fact
458 : * receive an integral number of complete
459 : * items, ignoring any timeouts (but not real
460 : * errors) (note that recursion is limited
461 : * since we don't propagate the element size
462 : * to the recursive call) */
463 0 : ssize_t n;
464 0 : n = socket_read(s, (char *) buf + nr, 1, size - (size_t) nr);
465 0 : if (n < 0) {
466 0 : if (s->errkind == MNSTR_NO__ERROR)
467 0 : mnstr_set_error(s, MNSTR_READ_ERROR, "socket_read failed");
468 0 : return -1;
469 : }
470 0 : if (n == 0) /* unexpected end of file */
471 : break;
472 0 : nr +=
473 : #ifdef _MSC_VER
474 : (int)
475 : #endif
476 : n;
477 : }
478 : }
479 7283480 : return nr / (ssize_t) elmsize;
480 : }
481 :
482 : static void
483 80049 : socket_close(stream *s)
484 : {
485 80049 : SOCKET fd = s->stream_data.s;
486 :
487 80049 : if (fd != INVALID_SOCKET) {
488 : /* Related read/write (in/out, from/to) streams
489 : * share a single socket which is not dup'ed (anymore)
490 : * as Windows' dup doesn't work on sockets;
491 : * hence, only one of the streams must/may close that
492 : * socket; we choose to let the read socket do the
493 : * job, since in mapi.c it may happen that the read
494 : * stream is closed before the write stream was even
495 : * created.
496 : */
497 80049 : if (s->readonly) {
498 : #ifdef HAVE_SHUTDOWN
499 40025 : shutdown(fd, SHUT_RDWR);
500 : #endif
501 40025 : closesocket(fd);
502 : }
503 : }
504 80049 : s->stream_data.s = INVALID_SOCKET;
505 80049 : }
506 :
507 : static void
508 38436 : socket_update_timeout(stream *s)
509 : {
510 38436 : SOCKET fd = s->stream_data.s;
511 38436 : struct timeval tv;
512 :
513 38436 : if (fd == INVALID_SOCKET)
514 0 : return;
515 38436 : tv.tv_sec = s->timeout / 1000;
516 38436 : tv.tv_usec = (s->timeout % 1000) * 1000;
517 : /* cast to char * for Windows, no harm on "normal" systems */
518 38436 : if (s->readonly)
519 38436 : (void) setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (char *) &tv, (socklen_t) sizeof(tv));
520 : else
521 0 : (void) setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, (socklen_t) sizeof(tv));
522 : }
523 :
524 : #ifndef MSG_DONTWAIT
525 : #define MSG_DONTWAIT 0
526 : #endif
527 :
528 : static int
529 0 : socket_isalive(const stream *s)
530 : {
531 0 : SOCKET fd = s->stream_data.s;
532 : #ifdef HAVE_POLL
533 0 : struct pollfd pfd;
534 0 : int ret;
535 0 : pfd = (struct pollfd){.fd = fd};
536 0 : if ((ret = poll(&pfd, 1, 0)) == 0)
537 : return 1;
538 0 : if (ret == -1 && errno == EINTR)
539 0 : return socket_isalive(s);
540 0 : if (ret < 0 || pfd.revents & (POLLERR | POLLHUP))
541 : return 0;
542 0 : assert(0); /* unexpected revents value */
543 : return 0;
544 : #else
545 : fd_set fds;
546 : struct timeval t;
547 : char buffer[32];
548 :
549 : t.tv_sec = 0;
550 : t.tv_usec = 0;
551 : FD_ZERO(&fds);
552 : FD_SET(fd, &fds);
553 : return select(
554 : #ifdef _MSC_VER
555 : 0, /* ignored on Windows */
556 : #else
557 : fd + 1,
558 : #endif
559 : &fds, NULL, NULL, &t) <= 0 ||
560 : recv(fd, buffer, sizeof(buffer), MSG_PEEK | MSG_DONTWAIT) != 0;
561 : #endif
562 : }
563 :
564 : static stream *
565 80064 : socket_open(SOCKET sock, const char *name)
566 : {
567 80064 : stream *s;
568 80064 : int domain = 0;
569 :
570 80064 : if (sock == INVALID_SOCKET) {
571 0 : mnstr_set_open_error(name, 0, "invalid socket");
572 0 : return NULL;
573 : }
574 80064 : if ((s = create_stream(name)) == NULL)
575 : return NULL;
576 80064 : s->read = socket_read;
577 80064 : s->write = socket_write;
578 80064 : s->close = socket_close;
579 80064 : s->stream_data.s = sock;
580 80064 : s->update_timeout = socket_update_timeout;
581 80064 : s->isalive = socket_isalive;
582 80064 : s->getoob = socket_getoob;
583 80064 : s->putoob = socket_putoob;
584 :
585 80064 : errno = 0;
586 : #ifdef _MSC_VER
587 : WSASetLastError(0);
588 : #endif
589 : #if defined(SO_DOMAIN)
590 : {
591 80064 : socklen_t len = (socklen_t) sizeof(domain);
592 80064 : if (getsockopt(sock, SOL_SOCKET, SO_DOMAIN, (void *) &domain, &len) == SOCKET_ERROR)
593 0 : domain = AF_INET; /* give it a value if call fails */
594 : }
595 : #else
596 : {
597 : struct sockaddr_storage a;
598 : socklen_t l = (socklen_t) sizeof(a);
599 : if (getpeername(sock, (struct sockaddr *) &a, &l) == 0) {
600 : domain = a.ss_family;
601 : }
602 : }
603 : #endif
604 : #ifdef HAVE_SYS_UN_H
605 80064 : if (domain == AF_UNIX) {
606 71114 : s->getoob = socket_getoob_unix;
607 71114 : s->putoob = socket_putoob_unix;
608 : }
609 : #endif
610 : #if defined(SO_KEEPALIVE) && !defined(WIN32)
611 80064 : if (domain != AF_UNIX) { /* not on UNIX sockets */
612 8950 : int opt = 1;
613 8950 : (void) setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (void *) &opt, sizeof(opt));
614 : }
615 : #endif
616 : #if defined(IPTOS_THROUGHPUT) && !defined(WIN32)
617 80064 : if (domain != AF_UNIX) { /* not on UNIX sockets */
618 8950 : int tos = IPTOS_THROUGHPUT;
619 :
620 8950 : (void) setsockopt(sock, IPPROTO_IP, IP_TOS, (void *) &tos, sizeof(tos));
621 : }
622 : #endif
623 : #ifdef TCP_NODELAY
624 80064 : if (domain != AF_UNIX) { /* not on UNIX sockets */
625 8950 : int nodelay = 1;
626 :
627 8950 : (void) setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (void *) &nodelay, sizeof(nodelay));
628 : }
629 : #endif
630 : #ifdef HAVE_FCNTL
631 : {
632 80064 : int fl = fcntl(sock, F_GETFL);
633 :
634 80064 : fl &= ~O_NONBLOCK;
635 80064 : if (fcntl(sock, F_SETFL, fl) < 0) {
636 0 : mnstr_set_error_errno(s, MNSTR_OPEN_ERROR, "fcntl unset O_NONBLOCK failed");
637 0 : return s;
638 : }
639 : }
640 : #endif
641 :
642 : return s;
643 : }
644 :
645 : stream *
646 40032 : socket_rstream(SOCKET sock, const char *name)
647 : {
648 40032 : stream *s = NULL;
649 :
650 : #ifdef STREAM_DEBUG
651 : fprintf(stderr, "socket_rstream %zd %s\n", (ssize_t) sock, name);
652 : #endif
653 40032 : if ((s = socket_open(sock, name)) != NULL)
654 40032 : s->binary = true;
655 40032 : return s;
656 : }
657 :
658 : stream *
659 40032 : socket_wstream(SOCKET sock, const char *name)
660 : {
661 40032 : stream *s;
662 :
663 : #ifdef STREAM_DEBUG
664 : fprintf(stderr, "socket_wstream %zd %s\n", (ssize_t) sock, name);
665 : #endif
666 40032 : if ((s = socket_open(sock, name)) == NULL)
667 : return NULL;
668 40032 : s->readonly = false;
669 40032 : s->binary = true;
670 40032 : return s;
671 : }
|