LCOV - code coverage report
Current view: top level - clients/mapilib - connect.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 289 449 64.4 %
Date: 2025-03-25 20:06:35 Functions: 10 10 100.0 %

          Line data    Source code
       1             : /*
       2             :  * SPDX-License-Identifier: MPL-2.0
       3             :  *
       4             :  * This Source Code Form is subject to the terms of the Mozilla Public
       5             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       6             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       7             :  *
       8             :  * Copyright 2024, 2025 MonetDB Foundation;
       9             :  * Copyright August 2008 - 2023 MonetDB B.V.;
      10             :  * Copyright 1997 - July 2008 CWI.
      11             :  */
      12             : 
      13             : #include "monetdb_config.h"
      14             : #include "stream.h"           /* include before mapi.h */
      15             : #include "stream_socket.h"
      16             : #include "mapi.h"
      17             : #include "mapi_prompt.h"
      18             : #include "mcrypt.h"
      19             : #include "matomic.h"
      20             : #include "mstring.h"
      21             : #include "mutils.h"
      22             : 
      23             : #include "mapi_intern.h"
      24             : 
      25             : #ifdef HAVE_SYS_SOCKET_H
      26             : # include <arpa/inet.h>                   /* addr_in */
      27             : # include <unistd.h>                      /* gethostname() */
      28             : #else /* UNIX specific */
      29             : #ifdef HAVE_WINSOCK_H                   /* Windows specific */
      30             : # include <winsock.h>
      31             : #endif
      32             : #endif
      33             : 
      34             : 
      35             : #ifdef HAVE_SYS_UN_H
      36             : #define DO_UNIX_DOMAIN (1)
      37             : #else
      38             : #define DO_UNIX_DOMAIN (0)
      39             : #endif
      40             : 
      41             : #ifdef _MSC_VER
      42             : #define SOCKET_STRERROR()       wsaerror(WSAGetLastError())
      43             : #else
      44             : #define SOCKET_STRERROR()       strerror(errno)
      45             : #endif
      46             : 
      47             : 
      48             : static MapiMsg scan_sockets(Mapi mid);
      49             : static MapiMsg connect_socket(Mapi mid);
      50             : static MapiMsg connect_socket_tcp(Mapi mid);
      51             : static SOCKET connect_socket_tcp_addr(Mapi mid, struct addrinfo *addr);
      52             : static MapiMsg mapi_handshake(Mapi mid);
      53             : 
      54             : #ifndef HAVE_OPENSSL
      55             : // The real implementation is in connect_openssl.c.
      56             : MapiMsg
      57             : wrap_tls(Mapi mid, SOCKET sock)
      58             : {
      59             :         closesocket(sock);
      60             :         return mapi_setError(mid, "Cannot connect to monetdbs://, not built with OpenSSL support", __func__, MERROR);
      61             : }
      62             : #endif // HAVE_OPENSSL
      63             : 
      64             : #ifndef HAVE_SYS_UN_H
      65             : 
      66             : MapiMsg
      67             : connect_socket_unix(Mapi mid)
      68             : {
      69             :         return mapi_setError(mid, "Unix domain sockets not supported", __func__, MERROR);
      70             : }
      71             : 
      72             : static MapiMsg
      73             : scan_unix_sockets(Mapi mid)
      74             : {
      75             :         return mapi_setError(mid, "Unix domain sockets not supported", __func__, MERROR);
      76             : }
      77             : 
      78             : #endif
      79             : 
      80             : 
      81             : 
      82             : 
      83             : /* (Re-)establish a connection with the server. */
      84             : MapiMsg
      85        1477 : mapi_reconnect(Mapi mid)
      86             : {
      87        1477 :         const char *err = msettings_validate(mid->settings);
      88        1477 :         if (err) {
      89           0 :                 mapi_setError(mid, err, __func__, MERROR);
      90           0 :                 return MERROR;
      91             :         }
      92             : 
      93             :         // If neither host nor port are given, scan the Unix domain sockets in
      94             :         // /tmp and see if any of them serve this database.
      95             :         // Otherwise, just try to connect to what was given.
      96        1477 :         if (msettings_connect_scan(mid->settings))
      97           2 :                 return scan_sockets(mid);
      98             :         else
      99        1475 :                 return establish_connection(mid);
     100             : }
     101             : 
     102             : static MapiMsg
     103           2 : scan_sockets(Mapi mid)
     104             : {
     105           2 :         msettings_error errmsg;
     106             : 
     107           2 :         if (scan_unix_sockets(mid) == MOK)
     108             :                 return MOK;
     109             : 
     110           2 :         errmsg = msetting_set_string(mid->settings, MP_HOST, "localhost");
     111           2 :         if (!errmsg)
     112           2 :                 errmsg = msettings_validate(mid->settings);
     113           2 :         if (errmsg) {
     114           0 :                 MapiMsg err = mapi_setError(mid, errmsg, __func__, MERROR);
     115           0 :                 return err;
     116             :         }
     117           2 :         return establish_connection(mid);
     118             : }
     119             : 
     120             : /* (Re-)establish a connection with the server. */
     121             : MapiMsg
     122        1477 : establish_connection(Mapi mid)
     123             : {
     124        1477 :         if (mid->connected) {
     125           1 :                 mapi_log_record(mid, "CONN", "Found leftover open connection");
     126           1 :                 close_connection(mid);
     127             :         }
     128             : 
     129             :         MapiMsg msg = MREDIRECT;
     130        2939 :         while (msg == MREDIRECT) {
     131             :                 // Generally at this point we need to set up a new TCP or Unix
     132             :                 // domain connection.
     133             :                 //
     134             :                 // The only exception is if mapi_handshake() below has decided
     135             :                 // that the handshake must be restarted on the existing
     136             :                 // connection.
     137        1477 :                 if (!mid->connected) {
     138        1477 :                         msg = connect_socket(mid);
     139        1477 :                         if (msg != MOK)
     140          15 :                                 return msg;
     141             :                 }
     142        1462 :                 msg = mapi_handshake(mid);
     143             :         }
     144             : 
     145             :         // Switch from MP_CONNECT_TIMEOUT to MP_REPLY_TIMEOUT
     146        1462 :         if (msg == MOK) {
     147        1449 :                 long connect_timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     148        1449 :                 long reply_timeout = msetting_long(mid->settings, MP_REPLY_TIMEOUT);
     149        1449 :                 if (connect_timeout > 0 || reply_timeout > 0) {
     150           0 :                         if (reply_timeout < 0)
     151             :                                 reply_timeout = 0;
     152           0 :                         msg = mapi_timeout(mid, reply_timeout);
     153             :                 }
     154             :         }
     155             : 
     156             :         return msg;
     157             : }
     158             : 
     159             : static MapiMsg
     160        1477 : connect_socket(Mapi mid)
     161             : {
     162        1477 :         assert(!mid->connected);
     163        1477 :         const char *sockname = msettings_connect_unix(mid->settings);
     164        1477 :         const char *tcp_host = msettings_connect_tcp(mid->settings);
     165        1477 :         long timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     166             : 
     167        1477 :         assert(*sockname || *tcp_host);
     168        2939 :         do {
     169        1477 :                 if (*sockname && connect_socket_unix(mid) == MOK)
     170             :                         break;
     171        1409 :                 if (*tcp_host && connect_socket_tcp(mid) == MOK)
     172             :                         break;
     173          15 :                 assert(mid->error == MERROR);
     174          15 :                 mid->error = MERROR; // in case assert above was not enabled
     175          15 :                 return mid->error;
     176             :         } while (0);
     177             : 
     178             :         // the socket code may have set SO_SNDTIMEO and SO_RCVTIMEO but
     179             :         // the mapi layer doesn't know this yet.
     180        1462 :         if (timeout > 0)
     181           0 :                 mapi_timeout(mid, timeout);
     182             : 
     183        1462 :         mid->connected = true;
     184        1462 :         return MOK;
     185             : }
     186             : 
     187             : MapiMsg
     188        1457 : wrap_socket(Mapi mid, SOCKET sock)
     189             : {
     190             :         // do not use check_stream here yet because the socket is not yet in 'mid'
     191        1457 :         stream *broken_stream = NULL;
     192        1457 :         MapiMsg msg;
     193        1457 :         stream *rstream = NULL;
     194        1457 :         stream *wstream = NULL;
     195             : 
     196        1457 :         wstream = socket_wstream(sock, "Mapi client write");
     197        1457 :         if (wstream == NULL || mnstr_errnr(wstream) != MNSTR_NO__ERROR) {
     198           0 :                 broken_stream = wstream;
     199           0 :                 goto bailout;
     200             :         }
     201             : 
     202        1457 :         rstream = socket_rstream(sock, "Mapi client write");
     203        1457 :         if (rstream == NULL || mnstr_errnr(rstream) != MNSTR_NO__ERROR) {
     204           0 :                 broken_stream = rstream;
     205           0 :                 goto bailout;
     206             :         }
     207             : 
     208        1457 :         msg = mapi_wrap_streams(mid, rstream, wstream);
     209        1457 :         if (msg != MOK)
     210           0 :                 goto bailout;
     211             :         return MOK;
     212             : 
     213           0 : bailout:
     214           0 :         if (rstream)
     215           0 :                 mnstr_destroy(rstream);
     216           0 :         if (wstream)
     217           0 :                 mnstr_destroy(wstream);
     218           0 :         closesocket(sock);
     219           0 :         if (broken_stream) {
     220           0 :                 char *error_message = "create stream from socket";
     221             :                 // malloc failure is the only way these calls could have failed
     222           0 :                 return mapi_printError(mid, __func__, MERROR, "%s: %s", error_message, mnstr_peek_error(broken_stream));
     223             :         } else {
     224             :                 return MERROR;
     225             :         }
     226             : }
     227             : 
     228             : static MapiMsg
     229        1409 : connect_socket_tcp(Mapi mid)
     230             : {
     231        1409 :         int ret;
     232             : 
     233        1409 :         bool use_tls = msetting_bool(mid->settings, MP_TLS);
     234        1409 :         const char *host = msettings_connect_tcp(mid->settings);
     235        1409 :         int port = msettings_connect_port(mid->settings);
     236             : 
     237        1409 :         assert(host);
     238        1409 :         char portbuf[10];
     239        1409 :         snprintf(portbuf, sizeof(portbuf), "%d", port);
     240             : 
     241        1409 :         mapi_log_record(mid, "CONN", "Connecting to %s:%d", host, port);
     242             : 
     243        1409 :         struct addrinfo hints = (struct addrinfo) {
     244             :                 .ai_family = AF_UNSPEC,
     245             :                 .ai_socktype = SOCK_STREAM,
     246             :                 .ai_protocol = IPPROTO_TCP,
     247             :         };
     248        1409 :         struct addrinfo *addresses;
     249        1409 :         ret = getaddrinfo(host, portbuf, &hints, &addresses);
     250        1409 :         if (ret != 0) {
     251           0 :                 return mapi_printError(
     252             :                         mid, __func__, MERROR,
     253             :                         "getaddrinfo %s:%s failed: %s", host, portbuf, gai_strerror(ret));
     254             :         }
     255        1409 :         if (addresses == NULL) {
     256           0 :                 return mapi_printError(
     257             :                         mid, __func__, MERROR,
     258             :                         "getaddrinfo return 0 addresses");
     259             :         }
     260             : 
     261             :         assert(addresses);
     262             :         SOCKET s;
     263        1423 :         for (struct addrinfo *addr = addresses; addr; addr = addr->ai_next) {
     264        1416 :                 s = connect_socket_tcp_addr(mid, addr);
     265        1416 :                 if (s != INVALID_SOCKET)
     266             :                         break;
     267             :         }
     268        1409 :         freeaddrinfo(addresses);
     269        1409 :         if (s == INVALID_SOCKET) {
     270             :                 // connect_socket_tcp_addr has already set an error message
     271             :                 return MERROR;
     272             :         }
     273             : 
     274             :         /* compare our own address with that of our peer and
     275             :          * if they are the same, we were connected to our own
     276             :          * socket, so then we can't use this connection */
     277        1402 :         union {
     278             :                 struct sockaddr_storage ss;
     279             :                 struct sockaddr_in i4;
     280             :                 struct sockaddr_in6 i6;
     281             :         } myaddr, praddr;
     282        1402 :         socklen_t myaddrlen, praddrlen;
     283        1402 :         myaddrlen = (socklen_t) sizeof(myaddr.ss);
     284        1402 :         praddrlen = (socklen_t) sizeof(praddr.ss);
     285        2804 :         if (getsockname(s, (struct sockaddr *) &myaddr.ss, &myaddrlen) == 0 &&
     286        1402 :                 getpeername(s, (struct sockaddr *) &praddr.ss, &praddrlen) == 0 &&
     287        2804 :                 myaddr.ss.ss_family == praddr.ss.ss_family &&
     288             :                 (myaddr.ss.ss_family == AF_INET
     289        1402 :                 ? myaddr.i4.sin_port == praddr.i4.sin_port
     290           0 :                 : myaddr.i6.sin6_port == praddr.i6.sin6_port) &&
     291             :                 (myaddr.ss.ss_family == AF_INET
     292           0 :                 ? myaddr.i4.sin_addr.s_addr == praddr.i4.sin_addr.s_addr
     293           0 :                 : memcmp(myaddr.i6.sin6_addr.s6_addr,
     294             :                         praddr.i6.sin6_addr.s6_addr,
     295             :                         sizeof(praddr.i6.sin6_addr.s6_addr)) == 0)) {
     296           0 :                 closesocket(s);
     297           0 :                 return mapi_setError(mid, "connected to self", __func__, MERROR);
     298             :         }
     299             : 
     300        1402 :         mapi_log_record(mid, "CONN", "Network connection established");
     301        1402 :         MapiMsg msg = use_tls ? wrap_tls(mid, s) : wrap_socket(mid, s);
     302        1402 :         if (msg != MOK)
     303             :                 return msg;
     304             : 
     305             :         return msg;
     306             : }
     307             : 
     308             : static SOCKET
     309        1416 : connect_socket_tcp_addr(Mapi mid, struct addrinfo *info)
     310             : {
     311        1416 :         long timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     312             : 
     313        1416 :         if (mid->tracelog) {
     314          15 :                 char addrbuf[100] = {0};
     315          15 :                 const char *addrtext;
     316          15 :                 int port;
     317          15 :                 if (info->ai_family == AF_INET) {
     318          15 :                         struct sockaddr_in *addr4 = (struct sockaddr_in*)info->ai_addr;
     319          15 :                         port = ntohs(addr4->sin_port);
     320          15 :                         void *addr = &addr4->sin_addr;
     321          15 :                         addrtext = inet_ntop(info->ai_family, addr, addrbuf, sizeof(addrbuf));
     322           0 :                 } else if (info->ai_family == AF_INET6) {
     323           0 :                         struct sockaddr_in6 *addr6 = (struct sockaddr_in6*)info->ai_addr;
     324           0 :                         port = ntohs(addr6->sin6_port);
     325           0 :                         void *addr = &addr6->sin6_addr;
     326           0 :                         addrtext = inet_ntop(info->ai_family, addr, addrbuf, sizeof(addrbuf));
     327             :                 } else {
     328             :                         port = -1;
     329             :                         addrtext = NULL;
     330             :                 }
     331          15 :                 mapi_log_record(mid, "CONN", "Trying IP %s port %d with timeout %ld", addrtext ? addrtext : "<UNKNOWN>", port, timeout);
     332             :         }
     333             : 
     334             : 
     335        1416 :         int socktype = info->ai_socktype;
     336             : #ifdef SOCK_CLOEXEC
     337        1416 :         socktype |= SOCK_CLOEXEC;
     338             : #endif
     339             : 
     340        1416 :         SOCKET s =  socket(info->ai_family, socktype, info->ai_protocol);
     341        1416 :         if (s == INVALID_SOCKET) {
     342          14 :                 mapi_printError(
     343             :                         mid, __func__, MERROR,
     344           7 :                         "could not create TCP socket: %s", SOCKET_STRERROR());
     345           7 :                 return INVALID_SOCKET;
     346             :         }
     347             : 
     348             : #if !defined(SOCK_CLOEXEC) && defined(HAVE_FCNTL)
     349             :         (void) fcntl(s, F_SETFD, FD_CLOEXEC);
     350             : #endif
     351             : 
     352        1409 :         if (timeout > 0) {
     353           0 :                 struct timeval tv = {
     354           0 :                         .tv_sec = timeout / 1000,
     355           0 :                         .tv_usec = timeout % 1000,
     356             :                 };
     357             :                 /* cast to char * for Windows, no harm on "normal" systems */
     358           0 :                 if (
     359           0 :                         setsockopt(s, SOL_SOCKET, SO_SNDTIMEO, (char*)&tv, sizeof(tv)) == SOCKET_ERROR
     360           0 :                         || setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(tv)) == SOCKET_ERROR
     361             :                 ) {
     362           0 :                         closesocket(s);
     363           0 :                         return mapi_printError(
     364             :                                 mid, __func__, MERROR,
     365           0 :                                 "could not set connect timeout: %s", strerror(errno));
     366             :                 }
     367             :         }
     368             : 
     369             :         // cast addrlen to int to satisfy Windows.
     370        1409 :         if (connect(s, info->ai_addr, (int)info->ai_addrlen) == SOCKET_ERROR) {
     371           7 :                 mapi_printError(
     372             :                         mid, __func__, MERROR,
     373           7 :                         "could not connect: %s", SOCKET_STRERROR());
     374           7 :                 closesocket(s);
     375           7 :                 return INVALID_SOCKET;
     376             :         }
     377             : 
     378             :         return s;
     379             : }
     380             : 
     381             : static const char *
     382        1252 : base_name(const char *file)
     383             : {
     384        1252 :         char *p = strrchr(file, '/');
     385             : #ifdef _MSC_VER
     386             :         char *q = strrchr(file, '\\');
     387             :         if (q != NULL) {
     388             :                 if (p == NULL || p < q)
     389             :                         p = q;
     390             :         }
     391             : #endif
     392        1252 :         if (p)
     393        1252 :                 return p + 1;
     394             :         return file;
     395             : }
     396             : 
     397             : static void
     398        1255 : send_all_clientinfo(Mapi mid)
     399             : {
     400        1255 :         msettings *mp = mid->settings;
     401        1255 :         void *free_this = NULL;
     402        1255 :         if (!mid->clientinfo_supported)
     403           2 :                 return;
     404        1255 :         if (!msetting_bool(mp, MP_CLIENT_INFO))
     405             :                 return;
     406             : 
     407             : 
     408        1253 :         static char hostname[120] = { 0 };
     409        1253 :         if (hostname[0] == '\0') {
     410         181 :                 if (gethostname(hostname, sizeof(hostname)) != 0)
     411           0 :                         hostname[0] = '\0';
     412         181 :                 hostname[sizeof(hostname) - 1] = '\0';
     413             :         }
     414        1253 :         const char *application_name = msetting_string(mp, MP_CLIENT_APPLICATION);
     415        1253 :         if (!application_name[0]) {
     416        1252 :                 application_name = get_bin_path();
     417        1252 :                 if (application_name) {
     418        1252 :                         free_this = strdup(application_name);
     419        1252 :                         application_name = base_name(application_name);
     420             :                 }
     421             :         }
     422        1253 :         const char *client_remark = msetting_string(mp, MP_CLIENT_REMARK);
     423        1253 :         long pid = getpid();
     424             : 
     425        1253 :         char *buf = NULL;
     426        1253 :         size_t pos = 0, cap = 200;
     427             : 
     428        1253 :         if (hostname[0])
     429        1253 :                 reallocprintf(&buf, &pos, &cap, "ClientHostname=%s\n", hostname);
     430        1253 :         if (application_name && application_name[0])
     431        1253 :                 reallocprintf(&buf, &pos, &cap, "ApplicationName=%s\n", application_name);
     432        1253 :         reallocprintf(&buf, &pos, &cap, "ClientLibrary=");
     433        1253 :         if (mid->clientprefix)
     434         102 :                 reallocprintf(&buf, &pos, &cap, "%s / ", mid->clientprefix);
     435        1253 :         reallocprintf(&buf, &pos, &cap, "libmapi %s\n", MONETDB_VERSION);
     436        1253 :         if (client_remark[0])
     437           6 :                 reallocprintf(&buf, &pos, &cap, "ClientRemark=%s\n", client_remark);
     438        1253 :         if (pid > 0)
     439        1253 :                 reallocprintf(&buf, &pos, &cap, "ClientPid=%ld\n", pid);
     440             : 
     441        1253 :         if (pos > 1) {
     442        1253 :                 assert(buf[pos - 1] == '\n');
     443        1253 :                 pos--;
     444        1253 :                 buf[pos] = '\0';
     445             :         }
     446             : 
     447        1253 :         if (pos <= cap)
     448        1253 :                 mapi_Xcommand(mid, "clientinfo", buf);
     449             : 
     450        1253 :         free(buf);
     451        1253 :         free(free_this);
     452             : }
     453             : 
     454             : static MapiMsg
     455        1462 : mapi_handshake(Mapi mid)
     456             : {
     457        1462 :         char buf[BLOCK];
     458        1462 :         size_t len;
     459        1462 :         MapiHdl hdl;
     460             : 
     461        1462 :         const char *username = msetting_string(mid->settings, MP_USER);
     462        1462 :         const char *password = msetting_string(mid->settings, MP_PASSWORD);
     463             : 
     464             :         /* consume server challenge */
     465        1462 :         len = mnstr_read_block(mid->from, buf, 1, sizeof(buf));
     466        1462 :         check_stream(mid, mid->from, len, "Connection terminated while starting handshake", (mid->blk.eos = true, mid->error));
     467             : 
     468        1461 :         mapi_log_data(mid, "RECV HANDSHAKE", buf, len);
     469             : 
     470        1461 :         assert(len < sizeof(buf));
     471        1461 :         buf[len] = 0;
     472             : 
     473        1461 :         if (len == 0) {
     474           0 :                 mapi_setError(mid, "Challenge string is not valid, it is empty", __func__, MERROR);
     475           0 :                 return mid->error;
     476             :         }
     477             :         /* buf at this point looks like "challenge:servertype:protover[:.*]" */
     478             : 
     479        1461 :         char *strtok_state = NULL;
     480        1461 :         char *chal = strtok_r(buf, ":", &strtok_state);
     481        1461 :         if (chal == NULL) {
     482           0 :                 mapi_setError(mid, "Challenge string is not valid, challenge not found", __func__, MERROR);
     483           0 :                 close_connection(mid);
     484           0 :                 return mid->error;
     485             :         }
     486             : 
     487        1461 :         char *server = strtok_r(NULL, ":", &strtok_state);
     488        1461 :         if (server == NULL) {
     489           0 :                 mapi_setError(mid, "Challenge string is not valid, server not found", __func__, MERROR);
     490           0 :                 close_connection(mid);
     491           0 :                 return mid->error;
     492             :         }
     493             : 
     494        1461 :         char *protover = strtok_r(NULL, ":", &strtok_state);
     495        1461 :         if (protover == NULL) {
     496           0 :                 mapi_setError(mid, "Challenge string is not valid, protocol not found", __func__, MERROR);
     497           0 :                 close_connection(mid);
     498           0 :                 return mid->error;
     499             :         }
     500        1461 :         int pversion = atoi(protover);
     501        1461 :         if (pversion != 9) {
     502             :                 /* because the headers changed, and because it makes no sense to
     503             :                  * try and be backwards (or forwards) compatible, we bail out
     504             :                  * with a friendly message saying so */
     505           0 :                 snprintf(buf, sizeof(buf), "unsupported protocol version: %d, "
     506             :                          "this client only supports version 9", pversion);
     507           0 :                 mapi_setError(mid, buf, __func__, MERROR);
     508           0 :                 close_connection(mid);
     509           0 :                 return mid->error;
     510             :         }
     511             : 
     512        1461 :         char *hashes = strtok_r(NULL, ":", &strtok_state);
     513        1461 :         if (hashes == NULL) {
     514             :                 /* protocol violation, not enough fields */
     515           0 :                 mapi_setError(mid, "Not enough fields in challenge string", __func__, MERROR);
     516           0 :                 close_connection(mid);
     517           0 :                 return mid->error;
     518             :         }
     519        1461 :         char *algsv[] = {
     520             :                 "RIPEMD160",
     521             :                 "SHA512",
     522             :                 "SHA384",
     523             :                 "SHA256",
     524             :                 "SHA224",
     525             :                 "SHA1",
     526             :                 NULL
     527             :         };
     528        1461 :         char **algs = algsv;
     529             : 
     530             :         /* rBuCQ9WTn3:mserver:9:RIPEMD160,SHA256,SHA1,MD5:LIT:SHA1: */
     531             : 
     532        1461 :         if (!*username || !*password) {
     533           0 :                 mapi_setError(mid, "both username and password must be set",
     534             :                                 __func__, MERROR);
     535           0 :                 close_connection(mid);
     536           0 :                 return mid->error;
     537             :         }
     538             : 
     539             :         /* the database has sent a list of supported hashes to us, it's
     540             :                 * in the form of a comma separated list and in the variable
     541             :                 * rest.  We try to use the strongest algorithm. */
     542             : 
     543             : 
     544             :         /* in rest now should be the byte order of the server */
     545        1461 :         char *byteo = strtok_r(NULL, ":", &strtok_state);
     546             : 
     547             :         /* Proto v9 is like v8, but mandates that the password is a
     548             :                 * hash, that is salted like in v8.  The hash algorithm is
     549             :                 * specified in the 6th field.  If we don't support it, we
     550             :                 * can't login. */
     551        1461 :         char *serverhash = strtok_r(NULL, ":", &strtok_state);
     552             : 
     553        1461 :         char *handshake_options = strtok_r(NULL, ":", &strtok_state);
     554        1461 :         if (handshake_options) {
     555        1455 :                 if (sscanf(handshake_options, "sql=%d", &mid->handshake_options) != 1) {
     556           0 :                         mapi_setError(mid, "invalid handshake options",
     557             :                                         __func__, MERROR);
     558           0 :                         close_connection(mid);
     559           0 :                         return mid->error;
     560             :                 }
     561             :         }
     562             : 
     563             :         /* skip the binary option */
     564        1461 :         char *binary = strtok_r(NULL, ":", &strtok_state);
     565        1461 :         (void)binary;
     566             : 
     567        1461 :         char *oobintr = strtok_r(NULL, ":", &strtok_state);
     568        1461 :         if (oobintr) {
     569        1455 :                 if (strcmp(oobintr, "OOBINTR=1") == 0) {
     570        1455 :                         mid->oobintr = true;
     571             :                 }
     572             :         }
     573             : 
     574        1461 :         char *clientinfo = strtok_r(NULL, ":", &strtok_state);
     575        1461 :         if (clientinfo) {
     576        1455 :                 if (strcmp(clientinfo, "CLIENTINFO") == 0) {
     577        1455 :                         mid->clientinfo_supported = true;
     578             :                 }
     579             :         }
     580             : 
     581             :         /* hash password, if not already */
     582        1461 :         if (password[0] != '\1') {
     583         273 :                 char *pwdhash = NULL;
     584         273 :                 if (strcmp(serverhash, "RIPEMD160") == 0) {
     585           0 :                         pwdhash = mcrypt_RIPEMD160Sum(password,
     586             :                                                         strlen(password));
     587         273 :                 } else if (strcmp(serverhash, "SHA512") == 0) {
     588         273 :                         pwdhash = mcrypt_SHA512Sum(password,
     589             :                                                         strlen(password));
     590           0 :                 } else if (strcmp(serverhash, "SHA384") == 0) {
     591           0 :                         pwdhash = mcrypt_SHA384Sum(password,
     592             :                                                         strlen(password));
     593           0 :                 } else if (strcmp(serverhash, "SHA256") == 0) {
     594           0 :                         pwdhash = mcrypt_SHA256Sum(password,
     595             :                                                         strlen(password));
     596           0 :                 } else if (strcmp(serverhash, "SHA224") == 0) {
     597           0 :                         pwdhash = mcrypt_SHA224Sum(password,
     598             :                                                         strlen(password));
     599           0 :                 } else if (strcmp(serverhash, "SHA1") == 0) {
     600           0 :                         pwdhash = mcrypt_SHA1Sum(password,
     601             :                                                         strlen(password));
     602             :                 } else {
     603           0 :                         snprintf(buf, sizeof(buf), "server requires unknown hash '%.100s'",
     604             :                                         serverhash);
     605           0 :                         close_connection(mid);
     606           0 :                         return mapi_setError(mid, buf, __func__, MERROR);
     607             :                 }
     608             : 
     609         273 :                 if (pwdhash == NULL) {
     610           0 :                         snprintf(buf, sizeof(buf), "allocation failure or unknown hash '%.100s'",
     611             :                                         serverhash);
     612           0 :                         close_connection(mid);
     613           0 :                         return mapi_setError(mid, buf, __func__, MERROR);
     614             :                 }
     615             : 
     616         273 :                 char *replacement_password = malloc(1 + strlen(pwdhash) + 1);
     617         273 :                 if (replacement_password == NULL) {
     618           0 :                         free(pwdhash);
     619           0 :                         close_connection(mid);
     620           0 :                         return mapi_setError(mid, "malloc failed", __func__, MERROR);
     621             :                 }
     622         273 :                 sprintf(replacement_password, "\1%s", pwdhash);
     623         273 :                 free(pwdhash);
     624         273 :                 msettings_error errmsg = msetting_set_string(mid->settings, MP_PASSWORD, replacement_password);
     625         273 :                 free(replacement_password);
     626         273 :                 if (errmsg != NULL) {
     627           0 :                         close_connection(mid);
     628           0 :                         return mapi_setError(mid, "could not stow hashed password", __func__, MERROR);
     629             :                 }
     630             :         }
     631             : 
     632             : 
     633        1461 :         const char *pw = msetting_string(mid->settings, MP_PASSWORD);
     634        1461 :         assert(*pw == '\1');
     635        1461 :         pw++;
     636             : 
     637        1461 :         char *hash = NULL;
     638        1461 :         for (; *algs != NULL; algs++) {
     639             :                 /* TODO: make this actually obey the separation by
     640             :                         * commas, and only allow full matches */
     641        1461 :                 if (strstr(hashes, *algs) != NULL) {
     642        1461 :                         char *pwh = mcrypt_hashPassword(*algs, pw, chal);
     643        1461 :                         size_t len;
     644        1461 :                         if (pwh == NULL)
     645           0 :                                 continue;
     646        1461 :                         len = strlen(pwh) + strlen(*algs) + 3 /* {}\0 */;
     647        1461 :                         hash = malloc(len);
     648        1461 :                         if (hash == NULL) {
     649           0 :                                 close_connection(mid);
     650           0 :                                 free(pwh);
     651           0 :                                 return mapi_setError(mid, "malloc failure", __func__, MERROR);
     652             :                         }
     653        1461 :                         snprintf(hash, len, "{%s}%s", *algs, pwh);
     654        1461 :                         free(pwh);
     655        1461 :                         break;
     656             :                 }
     657             :         }
     658        1461 :         if (hash == NULL) {
     659             :                 /* the server doesn't support what we can */
     660           0 :                 snprintf(buf, sizeof(buf), "unsupported hash algorithms: %.100s", hashes);
     661           0 :                 close_connection(mid);
     662           0 :                 return mapi_setError(mid, buf, __func__, MERROR);
     663             :         }
     664             : 
     665        1461 :         mnstr_set_bigendian(mid->from, strcmp(byteo, "BIG") == 0);
     666             : 
     667        1461 :         char *p = buf;
     668        1461 :         int remaining = sizeof(buf);
     669        1461 :         int n;
     670             : #define CHECK_SNPRINTF(...) \
     671             :         do { \
     672             :                 n = snprintf(p, remaining, __VA_ARGS__); \
     673             :                 if (n < remaining) { \
     674             :                         remaining -= n; \
     675             :                         p += n; \
     676             :                 } else { \
     677             :                         mapi_setError(mid, "combination of database name and user name too long", __func__, MERROR); \
     678             :                         free(hash); \
     679             :                         close_connection(mid); \
     680             :                         return mid->error; \
     681             :                 } \
     682             :         } while (0)
     683             : 
     684             : #ifdef WORDS_BIGENDIAN
     685             :         char *our_endian = "BIG";
     686             : #else
     687        1461 :         char *our_endian = "LIT";
     688             : #endif
     689             :         /* note: if we make the database field an empty string, it
     690             :                 * means we want the default.  However, it *should* be there. */
     691        1461 :         const char *language = msetting_string(mid->settings, MP_LANGUAGE);
     692        1461 :         const char *database = msetting_string(mid->settings, MP_DATABASE);
     693        1461 :         CHECK_SNPRINTF("%s:%s:%s:%s:%s:FILETRANS:",
     694             :                         our_endian,
     695             :                         username, hash,
     696             :                         language, database);
     697             : 
     698        1461 :         if (mid->handshake_options > MAPI_HANDSHAKE_AUTOCOMMIT) {
     699        1455 :                 CHECK_SNPRINTF("auto_commit=%d", msetting_bool(mid->settings, MP_AUTOCOMMIT));
     700             :         }
     701        1461 :         if (mid->handshake_options > MAPI_HANDSHAKE_REPLY_SIZE) {
     702        1455 :                 CHECK_SNPRINTF(",reply_size=%ld", msetting_long(mid->settings, MP_REPLYSIZE));
     703             :         }
     704        1461 :         if (mid->handshake_options > MAPI_HANDSHAKE_SIZE_HEADER) {
     705        1455 :                 CHECK_SNPRINTF(",size_header=%d", mid->sizeheader); // with underscore, despite X command without
     706             :         }
     707        1461 :         if (mid->handshake_options > MAPI_HANDSHAKE_COLUMNAR_PROTOCOL) {
     708        1455 :                 CHECK_SNPRINTF(",columnar_protocol=%d", mid->columnar_protocol);
     709             :         }
     710        1461 :         if (mid->handshake_options > MAPI_HANDSHAKE_TIME_ZONE) {
     711        1455 :                 CHECK_SNPRINTF(",time_zone=%ld", msetting_long(mid->settings, MP_TIMEZONE));
     712             :         }
     713        1461 :         if (mid->handshake_options > 0) {
     714        1455 :                 CHECK_SNPRINTF(":");
     715             :         }
     716        1461 :         CHECK_SNPRINTF("\n");
     717             : 
     718        1461 :         free(hash);
     719             : 
     720        1461 :         len = strlen(buf);
     721        1461 :         mapi_log_data(mid, "HANDSHAKE SEND", buf, len);
     722        1461 :         len = mnstr_write(mid->to, buf, 1, len);
     723        1461 :         check_stream(mid, mid->to, len, "Could not send initial byte sequence", mid->error);
     724        1461 :         len = mnstr_flush(mid->to, MNSTR_FLUSH_DATA);
     725        1461 :         check_stream(mid, mid->to, len, "Could not send initial byte sequence", mid->error);
     726             : 
     727             :         // Clear the redirects before we receive new ones
     728        1461 :         for (char **r = mid->redirects; *r != NULL; r++) {
     729           0 :                 free(*r);
     730           0 :                 *r = NULL;
     731             :         }
     732             : 
     733             :         /* consume the welcome message from the server */
     734        1461 :         hdl = mapi_new_handle(mid);
     735        1461 :         if (hdl == NULL) {
     736           0 :                 close_connection(mid);
     737           0 :                 return MERROR;
     738             :         }
     739        1461 :         mid->active = hdl;
     740        1461 :         read_into_cache(hdl, 0);
     741        1461 :         if (mid->error) {
     742          12 :                 char *errorstr = NULL;
     743          12 :                 MapiMsg error;
     744          12 :                 struct MapiResultSet *result;
     745             :                 /* propagate error from result to mid, the error probably is in
     746             :                  * the last produced result, not the first
     747             :                  * mapi_close_handle clears the errors, so save them first */
     748          24 :                 for (result = hdl->result; result; result = result->next) {
     749          12 :                         errorstr = result->errorstr;
     750          12 :                         result->errorstr = NULL;     /* clear these so errorstr doesn't get freed */
     751             :                 }
     752          12 :                 if (!errorstr)
     753           0 :                         errorstr = mid->errorstr;
     754          12 :                 error = mid->error;
     755             : 
     756          12 :                 if (hdl->result)
     757          12 :                         hdl->result->errorstr = NULL;     /* clear these so errorstr doesn't get freed */
     758          12 :                 mid->errorstr = NULL;
     759          12 :                 mapi_close_handle(hdl);
     760          12 :                 mapi_setError(mid, errorstr, __func__, error);
     761          12 :                 if (errorstr != mapi_nomem)
     762          12 :                         free(errorstr); /* now free it after a copy has been made */
     763          12 :                 close_connection(mid);
     764          12 :                 return mid->error;
     765             :         }
     766        1449 :         if (hdl->result && hdl->result->cache.line) {
     767             :                 int i;
     768             :                 size_t motdlen = 0;
     769             :                 struct MapiResultSet *result = hdl->result;
     770             : 
     771           0 :                 for (i = 0; i < result->cache.writer; i++) {
     772           0 :                         if (result->cache.line[i].rows) {
     773           0 :                                 char **r;
     774           0 :                                 int m;
     775           0 :                                 switch (result->cache.line[i].rows[0]) {
     776           0 :                                 case '#':
     777           0 :                                         motdlen += strlen(result->cache.line[i].rows) + 1;
     778           0 :                                         break;
     779             :                                 case '^':
     780             :                                         r = mid->redirects;
     781             :                                         m = NELEM(mid->redirects) - 1;
     782           0 :                                         while (*r != NULL && m > 0) {
     783           0 :                                                 m--;
     784           0 :                                                 r++;
     785             :                                         }
     786           0 :                                         if (m == 0)
     787             :                                                 break;
     788           0 :                                         *r++ = strdup(result->cache.line[i].rows + 1);
     789           0 :                                         *r = NULL;
     790           0 :                                         break;
     791             :                                 }
     792             :                         }
     793             :                 }
     794           0 :                 if (motdlen > 0) {
     795           0 :                         mid->motd = malloc(motdlen + 1);
     796           0 :                         *mid->motd = 0;
     797           0 :                         char *p = mid->motd;
     798           0 :                         for (i = 0; i < result->cache.writer; i++)
     799           0 :                                 if (result->cache.line[i].rows && result->cache.line[i].rows[0] == '#') {
     800           0 :                                         p = stpcpy(p, result->cache.line[i].rows);
     801           0 :                                         p = stpcpy(p, "\n");
     802             :                                 }
     803             :                 }
     804             : 
     805           0 :                 if (*mid->redirects != NULL) {
     806             :                         /* redirect, looks like:
     807             :                          * ^mapi:monetdb://localhost:50001/test?lang=sql&user=monetdb
     808             :                          * or
     809             :                          * ^mapi:merovingian://proxy?database=test */
     810             : 
     811             :                         /* first see if we reached our redirection limit */
     812           0 :                         if (mid->redircnt >= mid->redirmax) {
     813           0 :                                 mapi_close_handle(hdl);
     814           0 :                                 mapi_setError(mid, "too many redirects", __func__, MERROR);
     815           0 :                                 close_connection(mid);
     816           0 :                                 return mid->error;
     817             :                         }
     818           0 :                         mid->redircnt++;
     819             : 
     820             :                         /* we only implement following the first */
     821           0 :                         char *red = mid->redirects[0];
     822             : 
     823           0 :                         const char *error_message = NULL;
     824           0 :                         if ((error_message = msettings_parse_url(mid->settings, red))
     825           0 :                             || (error_message = msettings_validate(mid->settings))
     826             :                         ) {
     827           0 :                                 mapi_close_handle(hdl);
     828           0 :                                 close_connection(mid);
     829           0 :                                 MapiMsg err = mapi_printError(
     830             :                                         mid, __func__, MERROR,
     831             :                                         "%s: %s",
     832             :                                         error_message,
     833             :                                         red);
     834           0 :                                 return err;
     835             :                         }
     836             : 
     837           0 :                         if (strncmp("mapi:merovingian", red, 16) == 0) {
     838             :                                 // do not close the connection so caller knows to restart handshake
     839           0 :                                 mapi_log_record(mid, "HANDSHAKE", "Restarting handshake on current socket");
     840           0 :                                 assert(mid->connected);
     841             :                         } else {
     842           0 :                                 mapi_log_record(mid, "HANDSHAKE", "Redirected elsewhere, closing socket");
     843           0 :                                 close_connection(mid);
     844             :                         }
     845           0 :                         return MREDIRECT;
     846             :                 }
     847             :         }
     848        1449 :         mapi_close_handle(hdl);
     849             : 
     850        1449 :         if (mid->trace)
     851           0 :                 printf("connection established\n");
     852             : 
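     853             :         // Non-SQL sessions are done here: the option commands below are only sent for SQL. NOTE(review): an earlier comment mentioned an assert, but none remains.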
     854        1449 :         if (!msettings_lang_is_sql(mid->settings))
     855         194 :                 return mid->error;
     856             : 
     857        1255 :         if (mid->error != MOK)
     858             :                 return mid->error;
     859             : 
     860             :         /* use X commands to send options that couldn't be sent in the handshake */
     861             :         /* tell server about auto_commit, reply_size, size header and time zone if they weren't sent as handshake options */
     862        1255 :         bool autocommit = msetting_bool(mid->settings, MP_AUTOCOMMIT);
     863        1255 :         if (mid->handshake_options <= MAPI_HANDSHAKE_AUTOCOMMIT && autocommit != msetting_bool(msettings_default, MP_AUTOCOMMIT)) {
     864           0 :                 char buf[50];
     865           0 :                 sprintf(buf, "%d", !!autocommit);
     866           0 :                 MapiMsg result = mapi_Xcommand(mid, "auto_commit", buf);
     867           0 :                 if (result != MOK)
     868           0 :                         return mid->error;
     869             :         }
     870        1255 :         long replysize = msetting_long(mid->settings, MP_REPLYSIZE);
     871        1255 :         if (mid->handshake_options <= MAPI_HANDSHAKE_REPLY_SIZE && replysize != msetting_long(msettings_default, MP_REPLYSIZE)) {
     872           0 :                 char buf[50];
     873           0 :                 sprintf(buf, "%ld", replysize);
     874           0 :                 MapiMsg result = mapi_Xcommand(mid, "reply_size", buf);
     875           0 :                 if (result != MOK)
     876           0 :                         return mid->error;
     877             :         }
     878        1255 :         if (mid->handshake_options <= MAPI_HANDSHAKE_SIZE_HEADER && mid->sizeheader != MapiStructDefaults.sizeheader) {
     879           0 :                 char buf[50];
     880           0 :                 sprintf(buf, "%d", !!mid->sizeheader);
     881           0 :                 MapiMsg result = mapi_Xcommand(mid, "sizeheader", buf); // no underscore!
     882           0 :                 if (result != MOK)
     883           0 :                         return mid->error;
     884             :         }
     885             :         // There is no if  (mid->handshake_options <= MAPI_HANDSHAKE_COLUMNAR_PROTOCOL && mid->columnar_protocol != MapiStructDefaults.columnar_protocol)
     886             :         // The reason is that columnar_protocol is very new. If it isn't supported in the handshake it isn't supported at
     887             :         // all so sending the Xcommand would just give an error.
     888        1255 :         if (mid->handshake_options <= MAPI_HANDSHAKE_TIME_ZONE) {
     889           0 :                 mapi_set_time_zone(mid, msetting_long(mid->settings, MP_TIMEZONE));
     890             :         }
     891             : 
     892        1255 :         if (mid->error == MOK)
     893        1255 :                 send_all_clientinfo(mid);
     894             : 
     895        1255 :         return mid->error;
     896             : 
     897             : }

Generated by: LCOV version 1.14