LCOV - code coverage report
Current view: top level - clients/mapilib - connect.c (source / functions) Hit Total Coverage
Test: coverage.info Lines: 279 452 61.7 %
Date: 2024-10-03 20:03:20 Functions: 10 10 100.0 %

          Line data    Source code
       1             : /*
       2             :  * SPDX-License-Identifier: MPL-2.0
       3             :  *
       4             :  * This Source Code Form is subject to the terms of the Mozilla Public
       5             :  * License, v. 2.0.  If a copy of the MPL was not distributed with this
       6             :  * file, You can obtain one at http://mozilla.org/MPL/2.0/.
       7             :  *
       8             :  * Copyright 2024 MonetDB Foundation;
       9             :  * Copyright August 2008 - 2023 MonetDB B.V.;
      10             :  * Copyright 1997 - July 2008 CWI.
      11             :  */
      12             : 
      13             : #include "monetdb_config.h"
      14             : #include "stream.h"           /* include before mapi.h */
      15             : #include "stream_socket.h"
      16             : #include "mapi.h"
      17             : #include "mapi_prompt.h"
      18             : #include "mcrypt.h"
      19             : #include "matomic.h"
      20             : #include "mstring.h"
      21             : #include "mutils.h"
      22             : 
      23             : #include "mapi_intern.h"
      24             : 
      25             : #ifdef HAVE_SYS_SOCKET_H
      26             : # include <arpa/inet.h>                   /* addr_in */
      27             : # include <unistd.h>                      /* gethostname() */
      28             : #else /* UNIX specific */
      29             : #ifdef HAVE_WINSOCK_H                   /* Windows specific */
      30             : # include <winsock.h>
      31             : #endif
      32             : #endif
      33             : 
      34             : 
      35             : #ifdef HAVE_SYS_UN_H
      36             : #define DO_UNIX_DOMAIN (1)
      37             : #else
      38             : #define DO_UNIX_DOMAIN (0)
      39             : #endif
      40             : 
      41             : #ifdef _MSC_VER
      42             : #define SOCKET_STRERROR()       wsaerror(WSAGetLastError())
      43             : #else
      44             : #define SOCKET_STRERROR()       strerror(errno)
      45             : #endif
      46             : 
      47             : 
      48             : static MapiMsg scan_sockets(Mapi mid);
      49             : static MapiMsg connect_socket(Mapi mid);
      50             : static MapiMsg connect_socket_tcp(Mapi mid);
      51             : static SOCKET connect_socket_tcp_addr(Mapi mid, struct addrinfo *addr);
      52             : static MapiMsg mapi_handshake(Mapi mid);
      53             : 
      54             : #ifndef HAVE_OPENSSL
      55             : // The real implementation is in connect_openssl.c.
      56             : MapiMsg
      57             : wrap_tls(Mapi mid, SOCKET sock)
      58             : {
      59             :         closesocket(sock);
      60             :         return mapi_setError(mid, "Cannot connect to monetdbs://, not built with OpenSSL support", __func__, MERROR);
      61             : }
      62             : #endif // HAVE_OPENSSL
      63             : 
      64             : #ifndef HAVE_SYS_UN_H
      65             : 
      66             : MapiMsg
      67             : connect_socket_unix(Mapi mid)
      68             : {
      69             :         return mapi_setError(mid, "Unix domain sockets not supported", __func__, MERROR);
      70             : }
      71             : 
      72             : static MapiMsg
      73             : scan_unix_sockets(Mapi mid)
      74             : {
      75             :         return mapi_setError(mid, "Unix domain sockets not supported", __func__, MERROR);
      76             : }
      77             : 
      78             : #endif
      79             : 
      80             : 
      81             : 
      82             : 
      83             : /* (Re-)establish a connection with the server. */
      84             : MapiMsg
      85        1441 : mapi_reconnect(Mapi mid)
      86             : {
      87        1441 :         char *err = NULL;
      88        1441 :         if (!msettings_validate(mid->settings, &err)) {
      89           0 :                 mapi_setError(mid, err, __func__, MERROR);
      90           0 :                 free(err);
      91           0 :                 return MERROR;
      92             :         }
      93             : 
      94             :         // If neither host nor port are given, scan the Unix domain sockets in
      95             :         // /tmp and see if any of them serve this database.
      96             :         // Otherwise, just try to connect to what was given.
      97        1441 :         if (msettings_connect_scan(mid->settings))
      98           2 :                 return scan_sockets(mid);
      99             :         else
     100        1439 :                 return establish_connection(mid);
     101             : }
     102             : 
     103             : static MapiMsg
     104           2 : scan_sockets(Mapi mid)
     105             : {
     106           2 :         if (scan_unix_sockets(mid) == MOK)
     107             :                 return MOK;
     108             : 
     109             :         // When the Unix sockets have been scanned we can freely modify 'original'.
     110           2 :         msettings_error errmsg = msetting_set_string(mid->settings, MP_HOST, "localhost");
     111           2 :         char *allocated_errmsg = NULL;
     112           2 :         if (!errmsg && !msettings_validate(mid->settings, &allocated_errmsg)) {
     113           0 :                 errmsg = allocated_errmsg;
     114             :         }
     115           2 :         if (errmsg) {
     116           0 :                 MapiMsg err = mapi_setError(mid, errmsg, __func__, MERROR);
     117           0 :                 free(allocated_errmsg);
     118           0 :                 return err;
     119             :         }
     120           2 :         return establish_connection(mid);
     121             : }
     122             : 
     123             : /* (Re-)establish a connection with the server. */
     124             : MapiMsg
     125        1441 : establish_connection(Mapi mid)
     126             : {
     127        1441 :         if (mid->connected) {
     128           1 :                 mapi_log_record(mid, "CONN", "Found leftover open connection");
     129           1 :                 close_connection(mid);
     130             :         }
     131             : 
     132             :         MapiMsg msg = MREDIRECT;
     133        2867 :         while (msg == MREDIRECT) {
     134             :                 // Generally at this point we need to set up a new TCP or Unix
     135             :                 // domain connection.
     136             :                 //
     137             :                 // The only exception is if mapi_handshake() below has decided
     138             :                 // that the handshake must be restarted on the existing
     139             :                 // connection.
     140        1441 :                 if (!mid->connected) {
     141        1441 :                         msg = connect_socket(mid);
     142        1441 :                         if (msg != MOK)
     143          15 :                                 return msg;
     144             :                 }
     145        1426 :                 msg = mapi_handshake(mid);
     146             :         }
     147             : 
     148             :         // Switch from MP_CONNECT_TIMEOUT to MP_REPLY_TIMEOUT
     149        1426 :         if (msg == MOK) {
     150        1414 :                 long connect_timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     151        1414 :                 long reply_timeout = msetting_long(mid->settings, MP_REPLY_TIMEOUT);
     152        1414 :                 if (connect_timeout > 0 || reply_timeout > 0) {
     153           0 :                         if (reply_timeout < 0)
     154             :                                 reply_timeout = 0;
     155           0 :                         msg = mapi_timeout(mid, reply_timeout);
     156             :                 }
     157             :         }
     158             : 
     159             :         return msg;
     160             : }
     161             : 
     162             : static MapiMsg
     163        1441 : connect_socket(Mapi mid)
     164             : {
     165        1441 :         assert(!mid->connected);
     166        1441 :         const char *sockname = msettings_connect_unix(mid->settings);
     167        1441 :         const char *tcp_host = msettings_connect_tcp(mid->settings);
     168        1441 :         long timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     169             : 
     170        1441 :         assert(*sockname || *tcp_host);
     171        2867 :         do {
     172        1441 :                 if (*sockname && connect_socket_unix(mid) == MOK)
     173             :                         break;
     174        1371 :                 if (*tcp_host && connect_socket_tcp(mid) == MOK)
     175             :                         break;
     176          15 :                 assert(mid->error == MERROR);
     177          15 :                 mid->error = MERROR; // in case assert above was not enabled
     178          15 :                 return mid->error;
     179             :         } while (0);
     180             : 
     181             :         // the socket code may have set SO_SNDTIMEO and SO_RCVTIMEO but
     182             :         // the mapi layer doesn't know this yet.
     183        1426 :         if (timeout > 0)
     184           0 :                 mapi_timeout(mid, timeout);
     185             : 
     186        1426 :         mid->connected = true;
     187        1426 :         return MOK;
     188             : }
     189             : 
     190             : MapiMsg
     191        1422 : wrap_socket(Mapi mid, SOCKET sock)
     192             : {
     193             :         // do not use check_stream here yet because the socket is not yet in 'mid'
     194        1422 :         stream *broken_stream = NULL;
     195        1422 :         MapiMsg msg;
     196        1422 :         stream *rstream = NULL;
     197        1422 :         stream *wstream = NULL;
     198             : 
     199        1422 :         wstream = socket_wstream(sock, "Mapi client write");
     200        1422 :         if (wstream == NULL || mnstr_errnr(wstream) != MNSTR_NO__ERROR) {
     201           0 :                 broken_stream = wstream;
     202           0 :                 goto bailout;
     203             :         }
     204             : 
     205        1422 :         rstream = socket_rstream(sock, "Mapi client write");
     206        1422 :         if (rstream == NULL || mnstr_errnr(rstream) != MNSTR_NO__ERROR) {
     207           0 :                 broken_stream = rstream;
     208           0 :                 goto bailout;
     209             :         }
     210             : 
     211        1422 :         msg = mapi_wrap_streams(mid, rstream, wstream);
     212        1422 :         if (msg != MOK)
     213           0 :                 goto bailout;
     214             :         return MOK;
     215             : 
     216           0 : bailout:
     217           0 :         if (rstream)
     218           0 :                 mnstr_destroy(rstream);
     219           0 :         if (wstream)
     220           0 :                 mnstr_destroy(wstream);
     221           0 :         closesocket(sock);
     222           0 :         if (broken_stream) {
     223           0 :                 char *error_message = "create stream from socket";
     224             :                 // malloc failure is the only way these calls could have failed
     225           0 :                 return mapi_printError(mid, __func__, MERROR, "%s: %s", error_message, mnstr_peek_error(broken_stream));
     226             :         } else {
     227             :                 return MERROR;
     228             :         }
     229             : }
     230             : 
     231             : static MapiMsg
     232        1371 : connect_socket_tcp(Mapi mid)
     233             : {
     234        1371 :         int ret;
     235             : 
     236        1371 :         bool use_tls = msetting_bool(mid->settings, MP_TLS);
     237        1371 :         const char *host = msettings_connect_tcp(mid->settings);
     238        1371 :         int port = msettings_connect_port(mid->settings);
     239             : 
     240        1371 :         assert(host);
     241        1371 :         char portbuf[10];
     242        1371 :         snprintf(portbuf, sizeof(portbuf), "%d", port);
     243             : 
     244        1371 :         mapi_log_record(mid, "CONN", "Connecting to %s:%d", host, port);
     245             : 
     246        1371 :         struct addrinfo hints = (struct addrinfo) {
     247             :                 .ai_family = AF_UNSPEC,
     248             :                 .ai_socktype = SOCK_STREAM,
     249             :                 .ai_protocol = IPPROTO_TCP,
     250             :         };
     251        1371 :         struct addrinfo *addresses;
     252        1371 :         ret = getaddrinfo(host, portbuf, &hints, &addresses);
     253        1371 :         if (ret != 0) {
     254           0 :                 return mapi_printError(
     255             :                         mid, __func__, MERROR,
     256             :                         "getaddrinfo %s:%s failed: %s", host, portbuf, gai_strerror(ret));
     257             :         }
     258        1371 :         if (addresses == NULL) {
     259           0 :                 return mapi_printError(
     260             :                         mid, __func__, MERROR,
     261             :                         "getaddrinfo return 0 addresses");
     262             :         }
     263             : 
     264             :         assert(addresses);
     265             :         SOCKET s;
     266        1385 :         for (struct addrinfo *addr = addresses; addr; addr = addr->ai_next) {
     267        1378 :                 s = connect_socket_tcp_addr(mid, addr);
     268        1378 :                 if (s != INVALID_SOCKET)
     269             :                         break;
     270             :         }
     271        1371 :         freeaddrinfo(addresses);
     272        1371 :         if (s == INVALID_SOCKET) {
     273             :                 // connect_socket_tcp_addr has already set an error message
     274             :                 return MERROR;
     275             :         }
     276             : 
     277             :         /* compare our own address with that of our peer and
     278             :          * if they are the same, we were connected to our own
     279             :          * socket, so then we can't use this connection */
     280        1364 :         union {
     281             :                 struct sockaddr_storage ss;
     282             :                 struct sockaddr_in i4;
     283             :                 struct sockaddr_in6 i6;
     284             :         } myaddr, praddr;
     285        1364 :         socklen_t myaddrlen, praddrlen;
     286        1364 :         myaddrlen = (socklen_t) sizeof(myaddr.ss);
     287        1364 :         praddrlen = (socklen_t) sizeof(praddr.ss);
     288        2728 :         if (getsockname(s, (struct sockaddr *) &myaddr.ss, &myaddrlen) == 0 &&
     289        1364 :                 getpeername(s, (struct sockaddr *) &praddr.ss, &praddrlen) == 0 &&
     290        2728 :                 myaddr.ss.ss_family == praddr.ss.ss_family &&
     291             :                 (myaddr.ss.ss_family == AF_INET
     292        1364 :                 ? myaddr.i4.sin_port == praddr.i4.sin_port
     293           0 :                 : myaddr.i6.sin6_port == praddr.i6.sin6_port) &&
     294             :                 (myaddr.ss.ss_family == AF_INET
     295           0 :                 ? myaddr.i4.sin_addr.s_addr == praddr.i4.sin_addr.s_addr
     296           0 :                 : memcmp(myaddr.i6.sin6_addr.s6_addr,
     297             :                         praddr.i6.sin6_addr.s6_addr,
     298             :                         sizeof(praddr.i6.sin6_addr.s6_addr)) == 0)) {
     299           0 :                 closesocket(s);
     300           0 :                 return mapi_setError(mid, "connected to self", __func__, MERROR);
     301             :         }
     302             : 
     303        1364 :         mapi_log_record(mid, "CONN", "Network connection established");
     304        1364 :         MapiMsg msg = use_tls ? wrap_tls(mid, s) : wrap_socket(mid, s);
     305        1364 :         if (msg != MOK)
     306             :                 return msg;
     307             : 
     308             :         return msg;
     309             : }
     310             : 
     311             : static SOCKET
     312        1378 : connect_socket_tcp_addr(Mapi mid, struct addrinfo *info)
     313             : {
     314        1378 :         long timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
     315             : 
     316        1378 :         if (mid->tracelog) {
     317           0 :                 char addrbuf[100] = {0};
     318           0 :                 const char *addrtext;
     319           0 :                 int port;
     320           0 :                 if (info->ai_family == AF_INET) {
     321           0 :                         struct sockaddr_in *addr4 = (struct sockaddr_in*)info->ai_addr;
     322           0 :                         port = ntohs(addr4->sin_port);
     323           0 :                         void *addr = &addr4->sin_addr;
     324           0 :                         addrtext = inet_ntop(info->ai_family, addr, addrbuf, sizeof(addrbuf));
     325           0 :                 } else if (info->ai_family == AF_INET6) {
     326           0 :                         struct sockaddr_in6 *addr6 = (struct sockaddr_in6*)info->ai_addr;
     327           0 :                         port = ntohs(addr6->sin6_port);
     328           0 :                         void *addr = &addr6->sin6_addr;
     329           0 :                         addrtext = inet_ntop(info->ai_family, addr, addrbuf, sizeof(addrbuf));
     330             :                 } else {
     331             :                         port = -1;
     332             :                         addrtext = NULL;
     333             :                 }
     334           0 :                 mapi_log_record(mid, "CONN", "Trying IP %s port %d with timeout %ld", addrtext ? addrtext : "<UNKNOWN>", port, timeout);
     335             :         }
     336             : 
     337             : 
     338        1378 :         int socktype = info->ai_socktype;
     339             : #ifdef SOCK_CLOEXEC
     340        1378 :         socktype |= SOCK_CLOEXEC;
     341             : #endif
     342             : 
     343        1378 :         SOCKET s =  socket(info->ai_family, socktype, info->ai_protocol);
     344        1378 :         if (s == INVALID_SOCKET) {
     345          14 :                 mapi_printError(
     346             :                         mid, __func__, MERROR,
     347           7 :                         "could not create TCP socket: %s", SOCKET_STRERROR());
     348           7 :                 return INVALID_SOCKET;
     349             :         }
     350             : 
     351             : #if !defined(SOCK_CLOEXEC) && defined(HAVE_FCNTL)
     352             :         (void) fcntl(s, F_SETFD, FD_CLOEXEC);
     353             : #endif
     354             : 
     355        1371 :         if (timeout > 0) {
     356           0 :                 struct timeval tv = {
     357           0 :                         .tv_sec = timeout / 1000,
     358           0 :                         .tv_usec = timeout % 1000,
     359             :                 };
     360             :                 /* cast to char * for Windows, no harm on "normal" systems */
     361           0 :                 if (
     362           0 :                         setsockopt(s, SOL_SOCKET, SO_SNDTIMEO, (char*)&tv, sizeof(tv)) == SOCKET_ERROR
     363           0 :                         || setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, (char*)&tv, sizeof(tv)) == SOCKET_ERROR
     364             :                 ) {
     365           0 :                         closesocket(s);
     366           0 :                         return mapi_printError(
     367             :                                 mid, __func__, MERROR,
     368           0 :                                 "could not set connect timeout: %s", strerror(errno));
     369             :                 }
     370             :         }
     371             : 
     372             :         // cast addrlen to int to satisfy Windows.
     373        1371 :         if (connect(s, info->ai_addr, (int)info->ai_addrlen) == SOCKET_ERROR) {
     374           7 :                 mapi_printError(
     375             :                         mid, __func__, MERROR,
     376           7 :                         "could not connect: %s", SOCKET_STRERROR());
     377           7 :                 closesocket(s);
     378           7 :                 return INVALID_SOCKET;
     379             :         }
     380             : 
     381             :         return s;
     382             : }
     383             : 
     384             : static const char *
     385        1218 : base_name(const char *file)
     386             : {
     387        1218 :         char *p = strrchr(file, '/');
     388             : #ifdef _MSC_VER
     389             :         char *q = strrchr(file, '\\');
     390             :         if (q != NULL) {
     391             :                 if (p == NULL || p < q)
     392             :                         p = q;
     393             :         }
     394             : #endif
     395        1218 :         if (p)
     396        1218 :                 return p + 1;
     397             :         return file;
     398             : }
     399             : 
     400             : static void
     401        1221 : send_all_clientinfo(Mapi mid)
     402             : {
     403        1221 :         msettings *mp = mid->settings;
     404        1221 :         void *free_this = NULL;
     405        1221 :         if (!mid->clientinfo_supported)
     406           2 :                 return;
     407        1221 :         if (!msetting_bool(mp, MP_CLIENT_INFO))
     408             :                 return;
     409             : 
     410             : 
     411        1219 :         static char hostname[120] = { 0 };
     412        1219 :         if (hostname[0] == '\0') {
     413         218 :                 if (gethostname(hostname, sizeof(hostname)) != 0)
     414           0 :                         hostname[0] = '\0';
     415         218 :                 hostname[sizeof(hostname) - 1] = '\0';
     416             :         }
     417        1219 :         const char *application_name = msetting_string(mp, MP_CLIENT_APPLICATION);
     418        1219 :         if (!application_name[0]) {
     419        1218 :                 application_name = get_bin_path();
     420        1218 :                 if (application_name) {
     421        1218 :                         free_this = strdup(application_name);
     422        1218 :                         application_name = base_name(application_name);
     423             :                 }
     424             :         }
     425        1219 :         const char *client_remark = msetting_string(mp, MP_CLIENT_REMARK);
     426        1219 :         long pid = getpid();
     427             : 
     428        1219 :         char *buf = NULL;
     429        1219 :         size_t pos = 0, cap = 200;
     430             : 
     431        1219 :         if (hostname[0])
     432        1219 :                 reallocprintf(&buf, &pos, &cap, "ClientHostname=%s\n", hostname);
     433        1219 :         if (application_name && application_name[0])
     434        1219 :                 reallocprintf(&buf, &pos, &cap, "ApplicationName=%s\n", application_name);
     435        1219 :         reallocprintf(&buf, &pos, &cap, "ClientLibrary=");
     436        1219 :         if (mid->clientprefix)
     437          32 :                 reallocprintf(&buf, &pos, &cap, "%s / ", mid->clientprefix);
     438        1219 :         reallocprintf(&buf, &pos, &cap, "libmapi %s\n", MONETDB_VERSION);
     439        1219 :         if (client_remark[0])
     440           6 :                 reallocprintf(&buf, &pos, &cap, "ClientRemark=%s\n", client_remark);
     441        1219 :         if (pid > 0)
     442        1219 :                 reallocprintf(&buf, &pos, &cap, "ClientPid=%ld\n", pid);
     443             : 
     444        1219 :         if (pos > 1) {
     445        1219 :                 assert(buf[pos - 1] == '\n');
     446        1219 :                 pos--;
     447        1219 :                 buf[pos] = '\0';
     448             :         }
     449             : 
     450        1219 :         if (pos <= cap)
     451        1219 :                 mapi_Xcommand(mid, "clientinfo", buf);
     452             : 
     453        1219 :         free(buf);
     454        1219 :         free(free_this);
     455             : }
     456             : 
     457             : static MapiMsg
     458        1426 : mapi_handshake(Mapi mid)
     459             : {
     460        1426 :         char buf[BLOCK];
     461        1426 :         size_t len;
     462        1426 :         MapiHdl hdl;
     463             : 
     464        1426 :         const char *username = msetting_string(mid->settings, MP_USER);
     465        1426 :         const char *password = msetting_string(mid->settings, MP_PASSWORD);
     466             : 
     467             :         /* consume server challenge */
     468        1426 :         len = mnstr_read_block(mid->from, buf, 1, sizeof(buf));
     469        1426 :         check_stream(mid, mid->from, len, "Connection terminated while starting handshake", (mid->blk.eos = true, mid->error));
     470             : 
     471        1425 :         mapi_log_data(mid, "RECV HANDSHAKE", buf, len);
     472             : 
     473        1425 :         assert(len < sizeof(buf));
     474        1425 :         buf[len] = 0;
     475             : 
     476        1425 :         if (len == 0) {
     477           0 :                 mapi_setError(mid, "Challenge string is not valid, it is empty", __func__, MERROR);
     478           0 :                 return mid->error;
     479             :         }
     480             :         /* buf at this point looks like "challenge:servertype:protover[:.*]" */
     481             : 
     482        1425 :         char *strtok_state = NULL;
     483        1425 :         char *chal = strtok_r(buf, ":", &strtok_state);
     484        1425 :         if (chal == NULL) {
     485           0 :                 mapi_setError(mid, "Challenge string is not valid, challenge not found", __func__, MERROR);
     486           0 :                 close_connection(mid);
     487           0 :                 return mid->error;
     488             :         }
     489             : 
     490        1425 :         char *server = strtok_r(NULL, ":", &strtok_state);
     491        1425 :         if (server == NULL) {
     492           0 :                 mapi_setError(mid, "Challenge string is not valid, server not found", __func__, MERROR);
     493           0 :                 close_connection(mid);
     494           0 :                 return mid->error;
     495             :         }
     496             : 
     497        1425 :         char *protover = strtok_r(NULL, ":", &strtok_state);
     498        1425 :         if (protover == NULL) {
     499           0 :                 mapi_setError(mid, "Challenge string is not valid, protocol not found", __func__, MERROR);
     500           0 :                 close_connection(mid);
     501           0 :                 return mid->error;
     502             :         }
     503        1425 :         int pversion = atoi(protover);
     504        1425 :         if (pversion != 9) {
     505             :                 /* because the headers changed, and because it makes no sense to
     506             :                  * try and be backwards (or forwards) compatible, we bail out
     507             :                  * with a friendly message saying so */
     508           0 :                 snprintf(buf, sizeof(buf), "unsupported protocol version: %d, "
     509             :                          "this client only supports version 9", pversion);
     510           0 :                 mapi_setError(mid, buf, __func__, MERROR);
     511           0 :                 close_connection(mid);
     512           0 :                 return mid->error;
     513             :         }
     514             : 
     515        1425 :         char *hashes = strtok_r(NULL, ":", &strtok_state);
     516        1425 :         if (hashes == NULL) {
     517             :                 /* protocol violation, not enough fields */
     518           0 :                 mapi_setError(mid, "Not enough fields in challenge string", __func__, MERROR);
     519           0 :                 close_connection(mid);
     520           0 :                 return mid->error;
     521             :         }
     522        1425 :         char *algsv[] = {
     523             :                 "RIPEMD160",
     524             :                 "SHA512",
     525             :                 "SHA384",
     526             :                 "SHA256",
     527             :                 "SHA224",
     528             :                 "SHA1",
     529             :                 NULL
     530             :         };
     531        1425 :         char **algs = algsv;
     532             : 
     533             :         /* rBuCQ9WTn3:mserver:9:RIPEMD160,SHA256,SHA1,MD5:LIT:SHA1: */
     534             : 
     535        1425 :         if (!*username || !*password) {
     536           0 :                 mapi_setError(mid, "username and password must be set",
     537             :                                 __func__, MERROR);
     538           0 :                 close_connection(mid);
     539           0 :                 return mid->error;
     540             :         }
     541             : 
     542             :         /* the database has sent a list of supported hashes to us, it's
     543             :                 * in the form of a comma separated list and in the variable
     544             :                 * rest.  We try to use the strongest algorithm. */
     545             : 
     546             : 
     547             :         /* in rest now should be the byte order of the server */
     548        1425 :         char *byteo = strtok_r(NULL, ":", &strtok_state);
     549             : 
     550             :         /* Proto v9 is like v8, but mandates that the password is a
     551             :                 * hash, that is salted like in v8.  The hash algorithm is
     552             :                 * specified in the 6th field.  If we don't support it, we
     553             :                 * can't login. */
     554        1425 :         char *serverhash = strtok_r(NULL, ":", &strtok_state);
     555             : 
     556        1425 :         char *handshake_options = strtok_r(NULL, ":", &strtok_state);
     557        1425 :         if (handshake_options) {
     558        1420 :                 if (sscanf(handshake_options, "sql=%d", &mid->handshake_options) != 1) {
     559           0 :                         mapi_setError(mid, "invalid handshake options",
     560             :                                         __func__, MERROR);
     561           0 :                         close_connection(mid);
     562           0 :                         return mid->error;
     563             :                 }
     564             :         }
     565             : 
     566             :         /* skip the binary option */
     567        1425 :         char *binary = strtok_r(NULL, ":", &strtok_state);
     568        1425 :         (void)binary;
     569             : 
     570        1425 :         char *oobintr = strtok_r(NULL, ":", &strtok_state);
     571        1425 :         if (oobintr) {
     572        1420 :                 if (strcmp(oobintr, "OOBINTR=1") == 0) {
     573        1420 :                         mid->oobintr = true;
     574             :                 }
     575             :         }
     576             : 
     577        1425 :         char *clientinfo = strtok_r(NULL, ":", &strtok_state);
     578        1425 :         if (clientinfo) {
     579        1420 :                 if (strcmp(clientinfo, "CLIENTINFO") == 0) {
     580        1420 :                         mid->clientinfo_supported = true;
     581             :                 }
     582             :         }
     583             : 
     584             :         /* hash password, if not already */
     585        1425 :         if (password[0] != '\1') {
     586         237 :                 char *pwdhash = NULL;
     587         237 :                 if (strcmp(serverhash, "RIPEMD160") == 0) {
     588           0 :                         pwdhash = mcrypt_RIPEMD160Sum(password,
     589             :                                                         strlen(password));
     590         237 :                 } else if (strcmp(serverhash, "SHA512") == 0) {
     591         237 :                         pwdhash = mcrypt_SHA512Sum(password,
     592             :                                                         strlen(password));
     593           0 :                 } else if (strcmp(serverhash, "SHA384") == 0) {
     594           0 :                         pwdhash = mcrypt_SHA384Sum(password,
     595             :                                                         strlen(password));
     596           0 :                 } else if (strcmp(serverhash, "SHA256") == 0) {
     597           0 :                         pwdhash = mcrypt_SHA256Sum(password,
     598             :                                                         strlen(password));
     599           0 :                 } else if (strcmp(serverhash, "SHA224") == 0) {
     600           0 :                         pwdhash = mcrypt_SHA224Sum(password,
     601             :                                                         strlen(password));
     602           0 :                 } else if (strcmp(serverhash, "SHA1") == 0) {
     603           0 :                         pwdhash = mcrypt_SHA1Sum(password,
     604             :                                                         strlen(password));
     605             :                 } else {
     606           0 :                         snprintf(buf, sizeof(buf), "server requires unknown hash '%.100s'",
     607             :                                         serverhash);
     608           0 :                         close_connection(mid);
     609           0 :                         return mapi_setError(mid, buf, __func__, MERROR);
     610             :                 }
     611             : 
     612         237 :                 if (pwdhash == NULL) {
     613           0 :                         snprintf(buf, sizeof(buf), "allocation failure or unknown hash '%.100s'",
     614             :                                         serverhash);
     615           0 :                         close_connection(mid);
     616           0 :                         return mapi_setError(mid, buf, __func__, MERROR);
     617             :                 }
     618             : 
     619         237 :                 char *replacement_password = malloc(1 + strlen(pwdhash) + 1);
     620         237 :                 if (replacement_password == NULL) {
     621           0 :                         free(pwdhash);
     622           0 :                         close_connection(mid);
     623           0 :                         return mapi_setError(mid, "malloc failed", __func__, MERROR);
     624             :                 }
     625         237 :                 sprintf(replacement_password, "\1%s", pwdhash);
     626         237 :                 free(pwdhash);
     627         237 :                 msettings_error errmsg = msetting_set_string(mid->settings, MP_PASSWORD, replacement_password);
     628         237 :                 free(replacement_password);
     629         237 :                 if (errmsg != NULL) {
     630           0 :                         close_connection(mid);
     631           0 :                         return mapi_setError(mid, "could not stow hashed password", __func__, MERROR);
     632             :                 }
     633             :         }
     634             : 
     635             : 
     636        1425 :         const char *pw = msetting_string(mid->settings, MP_PASSWORD);
     637        1425 :         assert(*pw == '\1');
     638        1425 :         pw++;
     639             : 
     640        1425 :         char *hash = NULL;
     641        1425 :         for (; *algs != NULL; algs++) {
     642             :                 /* TODO: make this actually obey the separation by
     643             :                         * commas, and only allow full matches */
     644        1425 :                 if (strstr(hashes, *algs) != NULL) {
     645        1425 :                         char *pwh = mcrypt_hashPassword(*algs, pw, chal);
     646        1425 :                         size_t len;
     647        1425 :                         if (pwh == NULL)
     648           0 :                                 continue;
     649        1425 :                         len = strlen(pwh) + strlen(*algs) + 3 /* {}\0 */;
     650        1425 :                         hash = malloc(len);
     651        1425 :                         if (hash == NULL) {
     652           0 :                                 close_connection(mid);
     653           0 :                                 free(pwh);
     654           0 :                                 return mapi_setError(mid, "malloc failure", __func__, MERROR);
     655             :                         }
     656        1425 :                         snprintf(hash, len, "{%s}%s", *algs, pwh);
     657        1425 :                         free(pwh);
     658        1425 :                         break;
     659             :                 }
     660             :         }
     661        1425 :         if (hash == NULL) {
     662             :                 /* the server doesn't support what we can */
     663           0 :                 snprintf(buf, sizeof(buf), "unsupported hash algorithms: %.100s", hashes);
     664           0 :                 close_connection(mid);
     665           0 :                 return mapi_setError(mid, buf, __func__, MERROR);
     666             :         }
     667             : 
     668        1425 :         mnstr_set_bigendian(mid->from, strcmp(byteo, "BIG") == 0);
     669             : 
     670        1425 :         char *p = buf;
     671        1425 :         int remaining = sizeof(buf);
     672        1425 :         int n;
     673             : #define CHECK_SNPRINTF(...) \
     674             :         do { \
     675             :                 n = snprintf(p, remaining, __VA_ARGS__); \
     676             :                 if (n < remaining) { \
     677             :                         remaining -= n; \
     678             :                         p += n; \
     679             :                 } else { \
     680             :                         mapi_setError(mid, "combination of database name and user name too long", __func__, MERROR); \
     681             :                         free(hash); \
     682             :                         close_connection(mid); \
     683             :                         return mid->error; \
     684             :                 } \
     685             :         } while (0)
     686             : 
     687             : #ifdef WORDS_BIGENDIAN
     688             :         char *our_endian = "BIG";
     689             : #else
     690        1425 :         char *our_endian = "LIT";
     691             : #endif
     692             :         /* note: if we make the database field an empty string, it
     693             :                 * means we want the default.  However, it *should* be there. */
     694        1425 :         const char *language = msetting_string(mid->settings, MP_LANGUAGE);
     695        1425 :         const char *database = msetting_string(mid->settings, MP_DATABASE);
     696        1425 :         CHECK_SNPRINTF("%s:%s:%s:%s:%s:FILETRANS:",
     697             :                         our_endian,
     698             :                         username, hash,
     699             :                         language, database);
     700             : 
     701        1425 :         if (mid->handshake_options > MAPI_HANDSHAKE_AUTOCOMMIT) {
     702        1420 :                 CHECK_SNPRINTF("auto_commit=%d", msetting_bool(mid->settings, MP_AUTOCOMMIT));
     703             :         }
     704        1425 :         if (mid->handshake_options > MAPI_HANDSHAKE_REPLY_SIZE) {
     705        1420 :                 CHECK_SNPRINTF(",reply_size=%ld", msetting_long(mid->settings, MP_REPLYSIZE));
     706             :         }
     707        1425 :         if (mid->handshake_options > MAPI_HANDSHAKE_SIZE_HEADER) {
     708        1420 :                 CHECK_SNPRINTF(",size_header=%d", mid->sizeheader); // with underscore, despite X command without
     709             :         }
     710        1425 :         if (mid->handshake_options > MAPI_HANDSHAKE_COLUMNAR_PROTOCOL) {
     711        1420 :                 CHECK_SNPRINTF(",columnar_protocol=%d", mid->columnar_protocol);
     712             :         }
     713        1425 :         if (mid->handshake_options > MAPI_HANDSHAKE_TIME_ZONE) {
     714        1420 :                 CHECK_SNPRINTF(",time_zone=%ld", msetting_long(mid->settings, MP_TIMEZONE));
     715             :         }
     716        1425 :         if (mid->handshake_options > 0) {
     717        1420 :                 CHECK_SNPRINTF(":");
     718             :         }
     719        1425 :         CHECK_SNPRINTF("\n");
     720             : 
     721        1425 :         free(hash);
     722             : 
     723        1425 :         len = strlen(buf);
     724        1425 :         mapi_log_data(mid, "HANDSHAKE SEND", buf, len);
     725        1425 :         len = mnstr_write(mid->to, buf, 1, len);
     726        1425 :         check_stream(mid, mid->to, len, "Could not send initial byte sequence", mid->error);
     727        1425 :         len = mnstr_flush(mid->to, MNSTR_FLUSH_DATA);
     728        1425 :         check_stream(mid, mid->to, len, "Could not send initial byte sequence", mid->error);
     729             : 
     730             :         // Clear the redirects before we receive new ones
     731        1425 :         for (char **r = mid->redirects; *r != NULL; r++) {
     732           0 :                 free(*r);
     733           0 :                 *r = NULL;
     734             :         }
     735             : 
     736             :         /* consume the welcome message from the server */
     737        1425 :         hdl = mapi_new_handle(mid);
     738        1425 :         if (hdl == NULL) {
     739           0 :                 close_connection(mid);
     740           0 :                 return MERROR;
     741             :         }
     742        1425 :         mid->active = hdl;
     743        1425 :         read_into_cache(hdl, 0);
     744        1425 :         if (mid->error) {
     745          11 :                 char *errorstr = NULL;
     746          11 :                 MapiMsg error;
     747          11 :                 struct MapiResultSet *result;
     748             :                 /* propagate error from result to mid, the error probably is in
     749             :                  * the last produced result, not the first
     750             :                  * mapi_close_handle clears the errors, so save them first */
     751          22 :                 for (result = hdl->result; result; result = result->next) {
     752          11 :                         errorstr = result->errorstr;
     753          11 :                         result->errorstr = NULL;     /* clear these so errorstr doesn't get freed */
     754             :                 }
     755          11 :                 if (!errorstr)
     756           0 :                         errorstr = mid->errorstr;
     757          11 :                 error = mid->error;
     758             : 
     759          11 :                 if (hdl->result)
     760          11 :                         hdl->result->errorstr = NULL;     /* clear these so errorstr doesn't get freed */
     761          11 :                 mid->errorstr = NULL;
     762          11 :                 mapi_close_handle(hdl);
     763          11 :                 mapi_setError(mid, errorstr, __func__, error);
     764          11 :                 if (errorstr != mapi_nomem)
     765          11 :                         free(errorstr); /* now free it after a copy has been made */
     766          11 :                 close_connection(mid);
     767          11 :                 return mid->error;
     768             :         }
     769        1414 :         if (hdl->result && hdl->result->cache.line) {
     770             :                 int i;
     771             :                 size_t motdlen = 0;
     772             :                 struct MapiResultSet *result = hdl->result;
     773             : 
     774           0 :                 for (i = 0; i < result->cache.writer; i++) {
     775           0 :                         if (result->cache.line[i].rows) {
     776           0 :                                 char **r;
     777           0 :                                 int m;
     778           0 :                                 switch (result->cache.line[i].rows[0]) {
     779           0 :                                 case '#':
     780           0 :                                         motdlen += strlen(result->cache.line[i].rows) + 1;
     781           0 :                                         break;
     782             :                                 case '^':
     783             :                                         r = mid->redirects;
     784             :                                         m = NELEM(mid->redirects) - 1;
     785           0 :                                         while (*r != NULL && m > 0) {
     786           0 :                                                 m--;
     787           0 :                                                 r++;
     788             :                                         }
     789           0 :                                         if (m == 0)
     790             :                                                 break;
     791           0 :                                         *r++ = strdup(result->cache.line[i].rows + 1);
     792           0 :                                         *r = NULL;
     793           0 :                                         break;
     794             :                                 }
     795             :                         }
     796             :                 }
     797           0 :                 if (motdlen > 0) {
     798           0 :                         mid->motd = malloc(motdlen + 1);
     799           0 :                         *mid->motd = 0;
     800           0 :                         for (i = 0; i < result->cache.writer; i++)
     801           0 :                                 if (result->cache.line[i].rows && result->cache.line[i].rows[0] == '#') {
     802           0 :                                         strcat(mid->motd, result->cache.line[i].rows);
     803           0 :                                         strcat(mid->motd, "\n");
     804             :                                 }
     805             :                 }
     806             : 
     807           0 :                 if (*mid->redirects != NULL) {
     808             :                         /* redirect, looks like:
     809             :                          * ^mapi:monetdb://localhost:50001/test?lang=sql&user=monetdb
     810             :                          * or
     811             :                          * ^mapi:merovingian://proxy?database=test */
     812             : 
     813             :                         /* first see if we reached our redirection limit */
     814           0 :                         if (mid->redircnt >= mid->redirmax) {
     815           0 :                                 mapi_close_handle(hdl);
     816           0 :                                 mapi_setError(mid, "too many redirects", __func__, MERROR);
     817           0 :                                 close_connection(mid);
     818           0 :                                 return mid->error;
     819             :                         }
     820           0 :                         mid->redircnt++;
     821             : 
     822             :                         /* we only implement following the first */
     823           0 :                         char *red = mid->redirects[0];
     824             : 
     825           0 :                         char *error_message = NULL;
     826           0 :                         if (!msettings_parse_url(mid->settings, red, &error_message)
     827           0 :                             || !msettings_validate(mid->settings, &error_message)
     828             :                         ) {
     829           0 :                                 mapi_close_handle(hdl);
     830           0 :                                 close_connection(mid);
     831           0 :                                 MapiMsg err = mapi_printError(
     832             :                                         mid, __func__, MERROR,
     833             :                                         "%s: %s",
     834           0 :                                         error_message ? error_message : "invalid redirect",
     835             :                                         red);
     836           0 :                                 free(error_message);
     837           0 :                                 return err;
     838             :                         }
     839             : 
     840           0 :                         if (strncmp("mapi:merovingian", red, 16) == 0) {
     841             :                                 // do not close the connection so caller knows to restart handshake
     842           0 :                                 mapi_log_record(mid, "HANDSHAKE", "Restarting handshake on current socket");
     843           0 :                                 assert(mid->connected);
     844             :                         } else {
     845           0 :                                 mapi_log_record(mid, "HANDSHAKE", "Redirected elsewhere, closing socket");
     846           0 :                                 close_connection(mid);
     847             :                         }
     848           0 :                         return MREDIRECT;
     849             :                 }
     850             :         }
     851        1414 :         mapi_close_handle(hdl);
     852             : 
     853        1414 :         if (mid->trace)
     854           0 :                 printf("connection established\n");
     855             : 
     856             :         // I don't understand this assert.
     857        1414 :         if (!msettings_lang_is_sql(mid->settings))
     858         193 :                 return mid->error;
     859             : 
     860        1221 :         if (mid->error != MOK)
     861             :                 return mid->error;
     862             : 
     863             :         /* use X commands to send options that couldn't be sent in the handshake */
     864             :         /* tell server about auto_complete and cache limit if handshake options weren't used */
     865        1221 :         bool autocommit = msetting_bool(mid->settings, MP_AUTOCOMMIT);
     866        1221 :         if (mid->handshake_options <= MAPI_HANDSHAKE_AUTOCOMMIT && autocommit != msetting_bool(msettings_default, MP_AUTOCOMMIT)) {
     867           0 :                 char buf[50];
     868           0 :                 sprintf(buf, "%d", !!autocommit);
     869           0 :                 MapiMsg result = mapi_Xcommand(mid, "auto_commit", buf);
     870           0 :                 if (result != MOK)
     871           0 :                         return mid->error;
     872             :         }
     873        1221 :         long replysize = msetting_long(mid->settings, MP_REPLYSIZE);
     874        1221 :         if (mid->handshake_options <= MAPI_HANDSHAKE_REPLY_SIZE && replysize != msetting_long(msettings_default, MP_REPLYSIZE)) {
     875           0 :                 char buf[50];
     876           0 :                 sprintf(buf, "%ld", replysize);
     877           0 :                 MapiMsg result = mapi_Xcommand(mid, "reply_size", buf);
     878           0 :                 if (result != MOK)
     879           0 :                         return mid->error;
     880             :         }
     881        1221 :         if (mid->handshake_options <= MAPI_HANDSHAKE_SIZE_HEADER && mid->sizeheader != MapiStructDefaults.sizeheader) {
     882           0 :                 char buf[50];
     883           0 :                 sprintf(buf, "%d", !!mid->sizeheader);
     884           0 :                 MapiMsg result = mapi_Xcommand(mid, "sizeheader", buf); // no underscore!
     885           0 :                 if (result != MOK)
     886           0 :                         return mid->error;
     887             :         }
     888             :         // There is no if  (mid->handshake_options <= MAPI_HANDSHAKE_COLUMNAR_PROTOCOL && mid->columnar_protocol != MapiStructDefaults.columnar_protocol)
     889             :         // The reason is that columnar_protocol is very new. If it isn't supported in the handshake it isn't supported at
     890             :         // all so sending the Xcommand would just give an error.
     891        1221 :         if (mid->handshake_options <= MAPI_HANDSHAKE_TIME_ZONE) {
     892           0 :                 mapi_set_time_zone(mid, msetting_long(mid->settings, MP_TIMEZONE));
     893             :         }
     894             : 
     895        1221 :         if (mid->error == MOK)
     896        1221 :                 send_all_clientinfo(mid);
     897             : 
     898        1221 :         return mid->error;
     899             : 
     900             : }

Generated by: LCOV version 1.14