LCOV - code coverage report
Current view: top level - common/stream - socket_stream.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 117 141 83.0 %
Date: 2024-04-25 20:03:45 Functions: 8 8 100.0 %

          Line data    Source code
       1             : /*
       2             :  * SPDX-License-Identifier: MPL-2.0
       3             :  *
       4             :  * This Source Code Form is subject to the terms of the Mozilla Public
       5             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       6             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       7             :  *
       8             :  * Copyright 2024 MonetDB Foundation;
       9             :  * Copyright August 2008 - 2023 MonetDB B.V.;
      10             :  * Copyright 1997 - July 2008 CWI.
      11             :  */
      12             : 
      13             : /* Generic stream handling code such as init and close */
      14             : 
      15             : #include "monetdb_config.h"
      16             : #include "stream.h"
      17             : #include "stream_internal.h"
      18             : #ifdef HAVE_SYS_TIME_H
      19             : #include <sys/time.h>
      20             : #endif
      21             : 
      22             : 
      23             : /* ------------------------------------------------------------------ */
      24             : /* streams working on a socket */
      25             : 
      26             : static ssize_t
      27     1888899 : socket_write(stream *restrict s, const void *restrict buf, size_t elmsize, size_t cnt)
      28             : {
      29     1888899 :         size_t size = elmsize * cnt, res = 0;
      30             : #ifdef NATIVE_WIN32
      31             :         int nr = 0;
      32             : #else
      33     1888899 :         ssize_t nr = 0;
      34             : #endif
      35             : 
      36     1888899 :         if (s->errkind != MNSTR_NO__ERROR)
      37             :                 return -1;
      38             : 
      39     1888899 :         if (size == 0 || elmsize == 0)
      40           0 :                 return (ssize_t) cnt;
      41             : 
      42     1888899 :         errno = 0;
      43     3777631 :         while (res < size &&
      44             :                (
      45             : #ifdef NATIVE_WIN32
      46             :                        /* send works on int, make sure the argument fits */
      47             :                        ((nr = send(s->stream_data.s, (const char *) buf + res, (int) min(size - res, 1 << 16), 0)) > 0)
      48             : #else
      49     1888899 :                        ((nr = write(s->stream_data.s, (const char *) buf + res, size - res)) > 0)
      50             : #endif
      51         172 :                        || (nr < 0 && /* syscall failed */
      52         172 :                            s->timeout > 0 &&      /* potentially timeout */
      53             : #ifdef _MSC_VER
      54             :                            WSAGetLastError() == WSAEWOULDBLOCK &&
      55             : #else
      56           0 :                            (errno == EAGAIN
      57             : #if EAGAIN != EWOULDBLOCK
      58             :                             || errno == EWOULDBLOCK
      59             : #endif
      60           0 :                                    ) && /* it was! */
      61             : #endif
      62           0 :                            s->timeout_func != NULL &&        /* callback function exists */
      63           0 :                            !s->timeout_func(s->timeout_data))     /* callback says don't stop */
      64         172 :                        ||(nr < 0 &&
      65             : #ifdef _MSC_VER
      66             :                           WSAGetLastError() == WSAEINTR
      67             : #else
      68         172 :                           errno == EINTR
      69             : #endif
      70             :                                ))       /* interrupted */
      71             :                 ) {
      72     1888732 :                 errno = 0;
      73             : #ifdef _MSC_VER
      74             :                 WSASetLastError(0);
      75             : #endif
      76     1888732 :                 if (nr > 0)
      77     1888732 :                         res += (size_t) nr;
      78             :         }
      79     1888904 :         if (res >= elmsize)
      80     1888732 :                 return (ssize_t) (res / elmsize);
      81         172 :         if (nr < 0) {
      82         172 :                 if (s->timeout > 0 &&
      83             : #ifdef _MSC_VER
      84             :                     WSAGetLastError() == WSAEWOULDBLOCK
      85             : #else
      86           0 :                     (errno == EAGAIN
      87             : #if EAGAIN != EWOULDBLOCK
      88             :                      || errno == EWOULDBLOCK
      89             : #endif
      90             :                             )
      91             : #endif
      92             :                         )
      93           0 :                         mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
      94             :                 else
      95         172 :                         mnstr_set_error_errno(s, MNSTR_WRITE_ERROR, "socket write");
      96         172 :                 return -1;
      97             :         }
      98             :         return 0;
      99             : }
     100             : 
     101             : static ssize_t
     102     7149217 : socket_read(stream *restrict s, void *restrict buf, size_t elmsize, size_t cnt)
     103             : {
     104             : #ifdef _MSC_VER
     105             :         int nr = 0;
     106             : #else
     107     7149217 :         ssize_t nr = 0;
     108             : #endif
     109     7149217 :         size_t size = elmsize * cnt;
     110             : 
     111     7149217 :         if (s->errkind != MNSTR_NO__ERROR)
     112             :                 return -1;
     113     7149217 :         if (size == 0)
     114             :                 return 0;
     115             : 
     116             : #ifdef _MSC_VER
     117             :         /* recv only takes an int parameter, and read does not accept
     118             :          * sockets */
     119             :         if (size > INT_MAX)
     120             :                 size = elmsize * (INT_MAX / elmsize);
     121             : #endif
     122     7153360 :         for (;;) {
     123     7153360 :                 if (s->timeout) {
     124     6143076 :                         int ret;
     125             : #ifdef HAVE_POLL
     126     6143076 :                         struct pollfd pfd;
     127             : 
     128     6143076 :                         pfd = (struct pollfd) {.fd = s->stream_data.s,
     129             :                                                .events = POLLIN};
     130             : 
     131     6143076 :                         ret = poll(&pfd, 1, (int) s->timeout);
     132     6143087 :                         if (ret == -1 && errno == EINTR)
     133        4143 :                                 continue;
     134     6143087 :                         if (ret == -1 || (pfd.revents & POLLERR)) {
     135         124 :                                 mnstr_set_error_errno(s, MNSTR_READ_ERROR, "poll error");
     136         142 :                                 return -1;
     137             :                         }
     138             : #else
     139             :                         struct timeval tv;
     140             :                         fd_set fds;
     141             : 
     142             :                         errno = 0;
     143             : #ifdef _MSC_VER
     144             :                         WSASetLastError(0);
     145             : #endif
     146             :                         FD_ZERO(&fds);
     147             :                         FD_SET(s->stream_data.s, &fds);
     148             :                         tv.tv_sec = s->timeout / 1000;
     149             :                         tv.tv_usec = (s->timeout % 1000) * 1000;
     150             :                         ret = select(
     151             : #ifdef _MSC_VER
     152             :                                 0,      /* ignored on Windows */
     153             : #else
     154             :                                 s->stream_data.s + 1,
     155             : #endif
     156             :                                 &fds, NULL, NULL, &tv);
     157             :                         if (ret == SOCKET_ERROR) {
     158             :                                 mnstr_set_error_errno(s, MNSTR_READ_ERROR, "select");
     159             :                                 return -1;
     160             :                         }
     161             : #endif
     162     6142963 :                         if (ret == 0) {
     163        4161 :                                 if (s->timeout_func == NULL || s->timeout_func(s->timeout_data)) {
     164          18 :                                         mnstr_set_error(s, MNSTR_TIMEOUT, NULL);
     165          18 :                                         return -1;
     166             :                                 }
     167        4143 :                                 continue;
     168             :                         }
     169     6138802 :                         assert(ret == 1);
     170             : #ifdef HAVE_POLL
     171     6138802 :                         assert(pfd.revents & (POLLIN|POLLHUP));
     172             : #else
     173             :                         assert(FD_ISSET(s->stream_data.s, &fds));
     174             : #endif
     175             :                 }
     176             : #ifdef _MSC_VER
     177             :                 nr = recv(s->stream_data.s, buf, (int) size, 0);
     178             :                 if (nr == SOCKET_ERROR) {
     179             :                         mnstr_set_error_errno(s, MNSTR_READ_ERROR, "recv");
     180             :                         return -1;
     181             :                 }
     182             : #else
     183     7149086 :                 nr = read(s->stream_data.s, buf, size);
     184     7149104 :                 if (nr == -1 && errno == EINTR)
     185           0 :                         continue;
     186     7149104 :                 if (nr == -1) {
     187           1 :                         mnstr_set_error_errno(s, MNSTR_READ_ERROR, NULL);
     188           1 :                         return -1;
     189             :                 }
     190             : #endif
     191     7149103 :                 break;
     192             :         }
     193     7149103 :         if (nr == 0) {
     194       38395 :                 s->eof = true;
     195       38395 :                 return 0;       /* end of file */
     196             :         }
     197     7110708 :         if (elmsize > 1) {
     198      950821 :                 while ((size_t) nr % elmsize != 0) {
     199             :                         /* if elmsize > 1, we really expect that "the
     200             :                          * other side" wrote complete items in a
     201             :                          * single system call, so we expect to at
     202             :                          * least receive complete items, and hence we
     203             :                          * continue reading until we did in fact
     204             :                          * receive an integral number of complete
     205             :                          * items, ignoring any timeouts (but not real
     206             :                          * errors) (note that recursion is limited
     207             :                          * since we don't propagate the element size
     208             :                          * to the recursive call) */
     209           6 :                         ssize_t n;
     210           6 :                         n = socket_read(s, (char *) buf + nr, 1, size - (size_t) nr);
     211           0 :                         if (n < 0) {
     212           0 :                                 if (s->errkind == MNSTR_NO__ERROR)
     213           0 :                                         mnstr_set_error(s, MNSTR_READ_ERROR, "socket_read failed");
     214           0 :                                 return -1;
     215             :                         }
     216           0 :                         if (n == 0)     /* unexpected end of file */
     217             :                                 break;
     218           0 :                         nr +=
     219             : #ifdef _MSC_VER
     220             :                                 (int)
     221             : #endif
     222             :                                 n;
     223             :                 }
     224             :         }
     225     7110702 :         return nr / (ssize_t) elmsize;
     226             : }
     227             : 
     228             : static void
     229       79498 : socket_close(stream *s)
     230             : {
     231       79498 :         SOCKET fd = s->stream_data.s;
     232             : 
     233       79498 :         if (fd != INVALID_SOCKET) {
     234             :                 /* Related read/write (in/out, from/to) streams
     235             :                  * share a single socket which is not dup'ed (anymore)
     236             :                  * as Windows' dup doesn't work on sockets;
     237             :                  * hence, only one of the streams must/may close that
     238             :                  * socket; we choose to let the read socket do the
     239             :                  * job, since in mapi.c it may happen that the read
     240             :                  * stream is closed before the write stream was even
     241             :                  * created.
     242             :                  */
     243       79498 :                 if (s->readonly) {
     244             : #ifdef HAVE_SHUTDOWN
     245       39749 :                         shutdown(fd, SHUT_RDWR);
     246             : #endif
     247       39749 :                         closesocket(fd);
     248             :                 }
     249             :         }
     250       79498 :         s->stream_data.s = INVALID_SOCKET;
     251       79498 : }
     252             : 
     253             : static void
     254       38205 : socket_update_timeout(stream *s)
     255             : {
     256       38205 :         SOCKET fd = s->stream_data.s;
     257       38205 :         struct timeval tv;
     258             : 
     259       38205 :         if (fd == INVALID_SOCKET)
     260           0 :                 return;
     261       38205 :         tv.tv_sec = s->timeout / 1000;
     262       38205 :         tv.tv_usec = (s->timeout % 1000) * 1000;
     263             :         /* cast to char * for Windows, no harm on "normal" systems */
     264       38205 :         if (!s->readonly)
     265           6 :                 (void) setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (char *) &tv, (socklen_t) sizeof(tv));
     266             : }
     267             : 
     268             : #ifndef MSG_DONTWAIT
     269             : #define MSG_DONTWAIT 0
     270             : #endif
     271             : 
     272             : static int
     273     7771648 : socket_isalive(const stream *s)
     274             : {
     275     7771648 :         SOCKET fd = s->stream_data.s;
     276             : #ifdef HAVE_POLL
     277     7771648 :         struct pollfd pfd;
     278     7771648 :         int ret;
     279     7771648 :         pfd = (struct pollfd){.fd = fd};
     280     7771648 :         if ((ret = poll(&pfd, 1, 0)) == 0)
     281             :                 return 1;
     282           0 :         if (ret == -1 && errno == EINTR)
     283           0 :                 return socket_isalive(s);
     284           0 :         if (ret < 0 || pfd.revents & (POLLERR | POLLHUP))
     285             :                 return 0;
     286           0 :         assert(0);              /* unexpected revents value */
     287             :         return 0;
     288             : #else
     289             :         fd_set fds;
     290             :         struct timeval t;
     291             :         char buffer[32];
     292             : 
     293             :         t.tv_sec = 0;
     294             :         t.tv_usec = 0;
     295             :         FD_ZERO(&fds);
     296             :         FD_SET(fd, &fds);
     297             :         return select(
     298             : #ifdef _MSC_VER
     299             :                 0,      /* ignored on Windows */
     300             : #else
     301             :                 fd + 1,
     302             : #endif
     303             :                 &fds, NULL, NULL, &t) <= 0 ||
     304             :                 recv(fd, buffer, sizeof(buffer), MSG_PEEK | MSG_DONTWAIT) != 0;
     305             : #endif
     306             : }
     307             : 
     308             : static stream *
     309       79512 : socket_open(SOCKET sock, const char *name)
     310             : {
     311       79512 :         stream *s;
     312       79512 :         int domain = 0;
     313             : 
     314       79512 :         if (sock == INVALID_SOCKET) {
     315           0 :                 mnstr_set_open_error(name, 0, "invalid socket");
     316           0 :                 return NULL;
     317             :         }
     318       79512 :         if ((s = create_stream(name)) == NULL)
     319             :                 return NULL;
     320       79512 :         s->read = socket_read;
     321       79512 :         s->write = socket_write;
     322       79512 :         s->close = socket_close;
     323       79512 :         s->stream_data.s = sock;
     324       79512 :         s->update_timeout = socket_update_timeout;
     325       79512 :         s->isalive = socket_isalive;
     326             : 
     327       79512 :         errno = 0;
     328             : #ifdef _MSC_VER
     329             :         WSASetLastError(0);
     330             : #endif
     331             : #if defined(SO_DOMAIN)
     332             :         {
     333       79512 :                 socklen_t len = (socklen_t) sizeof(domain);
     334       79512 :                 if (getsockopt(sock, SOL_SOCKET, SO_DOMAIN, (void *) &domain, &len) == SOCKET_ERROR)
     335           0 :                         domain = AF_INET;       /* give it a value if call fails */
     336             :         }
     337             : #endif
     338             : #if defined(SO_KEEPALIVE) && !defined(WIN32)
     339       79512 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     340        8598 :                 int opt = 1;
     341        8598 :                 (void) setsockopt(sock, SOL_SOCKET, SO_KEEPALIVE, (void *) &opt, sizeof(opt));
     342             :         }
     343             : #endif
     344             : #if defined(IPTOS_THROUGHPUT) && !defined(WIN32)
     345       79512 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     346        8598 :                 int tos = IPTOS_THROUGHPUT;
     347             : 
     348        8598 :                 (void) setsockopt(sock, IPPROTO_IP, IP_TOS, (void *) &tos, sizeof(tos));
     349             :         }
     350             : #endif
     351             : #ifdef TCP_NODELAY
     352       79512 :         if (domain != PF_UNIX) {        /* not on UNIX sockets */
     353        8598 :                 int nodelay = 1;
     354             : 
     355        8598 :                 (void) setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (void *) &nodelay, sizeof(nodelay));
     356             :         }
     357             : #endif
     358             : #ifdef HAVE_FCNTL
     359             :         {
     360       79512 :                 int fl = fcntl(sock, F_GETFL);
     361             : 
     362       79512 :                 fl &= ~O_NONBLOCK;
     363       79512 :                 if (fcntl(sock, F_SETFL, fl) < 0) {
     364           0 :                         mnstr_set_error_errno(s, MNSTR_OPEN_ERROR, "fcntl unset O_NONBLOCK failed");
     365           0 :                         return s;
     366             :                 }
     367             :         }
     368             : #endif
     369             : 
     370             :         return s;
     371             : }
     372             : 
     373             : stream *
     374       39756 : socket_rstream(SOCKET sock, const char *name)
     375             : {
     376       39756 :         stream *s = NULL;
     377             : 
     378             : #ifdef STREAM_DEBUG
     379             :         fprintf(stderr, "socket_rstream %zd %s\n", (ssize_t) sock, name);
     380             : #endif
     381       39756 :         if ((s = socket_open(sock, name)) != NULL)
     382       39756 :                 s->binary = true;
     383       39756 :         return s;
     384             : }
     385             : 
     386             : stream *
     387       39756 : socket_wstream(SOCKET sock, const char *name)
     388             : {
     389       39756 :         stream *s;
     390             : 
     391             : #ifdef STREAM_DEBUG
     392             :         fprintf(stderr, "socket_wstream %zd %s\n", (ssize_t) sock, name);
     393             : #endif
     394       39756 :         if ((s = socket_open(sock, name)) == NULL)
     395             :                 return NULL;
     396       39756 :         s->readonly = false;
     397       39756 :         s->binary = true;
     398       39756 :         return s;
     399             : }

Generated by: LCOV version 1.14