Line data Source code
1 : /*
2 : * SPDX-License-Identifier: MPL-2.0
3 : *
4 : * This Source Code Form is subject to the terms of the Mozilla Public
5 : * License, v. 2.0. If a copy of the MPL was not distributed with this
6 : * file, You can obtain one at http://mozilla.org/MPL/2.0/.
7 : *
8 : * Copyright 2024, 2025 MonetDB Foundation;
9 : * Copyright August 2008 - 2023 MonetDB B.V.;
10 : * Copyright 1997 - July 2008 CWI.
11 : */
12 :
13 : #include "monetdb_config.h"
14 : #include "stream.h" /* include before mapi.h */
15 : #include "stream_socket.h"
16 : #include "mapi.h"
17 : #include "mapi_prompt.h"
18 : #include "mcrypt.h"
19 : #include "matomic.h"
20 : #include "mstring.h"
21 : #include "mutils.h"
22 :
23 : #include "mapi_intern.h"
24 :
25 : #define MAX_SCAN (24)
26 :
27 : MapiMsg
28 2 : scan_unix_sockets(Mapi mid)
29 : {
30 2 : struct {
31 : int port;
32 : int priority;
33 : } candidates[MAX_SCAN];
34 2 : int ncandidates = 0;
35 2 : DIR *dir = NULL;
36 2 : struct dirent *entry;
37 :
38 2 : const char *sockdir = msetting_string(mid->settings, MP_SOCKDIR);
39 2 : size_t len = strlen(sockdir);
40 2 : char *namebuf = malloc(len + 50);
41 2 : if (namebuf == NULL)
42 0 : return mapi_setError(mid, "malloc failed", __func__, MERROR);
43 2 : strcpy(namebuf, sockdir);
44 2 : strcpy(namebuf + len, "/.s.monetdb.PORTXXXXX");
45 2 : char *put_port_here = strrchr(namebuf, 'P');
46 :
47 2 : msettings *original = mid->settings;
48 2 : mid->settings = NULL; // invalid state, will fix it before use and on return
49 :
50 2 : mapi_log_record(mid, "CONN", "Scanning %s for Unix domain sockets", sockdir);
51 :
52 : // Make a list of Unix domain sockets in /tmp
53 2 : uid_t me = getuid();
54 2 : dir = opendir(sockdir);
55 2 : if (dir) {
56 208 : while (ncandidates < MAX_SCAN && (entry = readdir(dir)) != NULL) {
57 206 : const char *basename = entry->d_name;
58 206 : if (strncmp(basename, ".s.monetdb.", 11) != 0 || basename[11] == '\0' || strlen(basename) > 20)
59 206 : continue;
60 :
61 0 : char *end;
62 0 : long port = strtol(basename + 11, &end, 10);
63 0 : if (port < 1 || port > 65535 || *end)
64 0 : continue;
65 :
66 0 : sprintf(put_port_here, "%ld", port);
67 0 : struct stat st;
68 0 : if (stat(namebuf, &st) < 0 || !S_ISSOCK(st.st_mode))
69 0 : continue;
70 :
71 0 : candidates[ncandidates].port = port;
72 0 : candidates[ncandidates++].priority = st.st_uid == me ? 0 : 1;
73 : }
74 2 : closedir(dir);
75 : }
76 :
77 2 : mapi_log_record(mid, "CONN", "Found %d Unix domain sockets", ncandidates);
78 :
79 : // Try those owned by us first, then all others
80 6 : for (int round = 0; round < 2; round++) {
81 4 : for (int i = 0; i < ncandidates; i++) {
82 0 : if (candidates[i].priority != round)
83 0 : continue;
84 :
85 0 : assert(!mid->connected);
86 0 : assert(mid->settings == NULL);
87 0 : mid->settings = msettings_clone(original);
88 0 : if (!mid->settings) {
89 0 : mid->settings = original;
90 0 : free(namebuf);
91 0 : return mapi_setError(mid, "malloc failed", __func__, MERROR);
92 : }
93 0 : msettings_error errmsg = msetting_set_long(mid->settings, MP_PORT, candidates[i].port);
94 0 : if (!errmsg)
95 0 : errmsg = msettings_validate(mid->settings);
96 0 : if (errmsg) {
97 0 : mapi_setError(mid, errmsg, __func__, MERROR);
98 0 : free(namebuf);
99 0 : msettings_destroy(mid->settings);
100 0 : mid->settings = original;
101 0 : return MERROR;
102 : }
103 0 : MapiMsg msg = establish_connection(mid);
104 0 : if (msg == MOK) {
105 : // do not restore original
106 0 : msettings_destroy(original);
107 0 : free(namebuf);
108 0 : return MOK;
109 : } else {
110 0 : msettings_destroy(mid->settings);
111 0 : mid->settings = NULL;
112 : // now we're ready to try another one
113 : }
114 : }
115 : }
116 :
117 2 : free(namebuf);
118 2 : assert(mid->settings == NULL);
119 2 : mid->settings = original;
120 2 : mapi_log_record(mid, "CONN", "All %d Unix domain sockets failed. Falling back to TCP", ncandidates);
121 : return MERROR;
122 : }
123 :
124 :
125 : MapiMsg
126 81 : connect_socket_unix(Mapi mid)
127 : {
128 81 : const char *sockname = msettings_connect_unix(mid->settings);
129 81 : assert (*sockname != '\0');
130 81 : long timeout = msetting_long(mid->settings, MP_CONNECT_TIMEOUT);
131 :
132 81 : mapi_log_record(mid, "CONN", "Connecting to Unix domain socket %s with timeout %ld", sockname, timeout);
133 :
134 81 : struct sockaddr_un userver;
135 81 : if (strlen(sockname) >= sizeof(userver.sun_path)) {
136 0 : return mapi_printError(mid, __func__, MERROR, "path name '%s' too long", sockname);
137 : }
138 :
139 : // Create the socket, taking care of CLOEXEC and SNDTIMEO
140 :
141 : #ifdef SOCK_CLOEXEC
142 81 : int s = socket(PF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
143 : #else
144 : int s = socket(PF_UNIX, SOCK_STREAM, 0);
145 : #endif
146 81 : if (s == INVALID_SOCKET) {
147 0 : return mapi_printError(
148 : mid, __func__, MERROR,
149 0 : "could not create Unix domain socket '%s': %s", sockname, strerror(errno));
150 : }
151 : #if !defined(SOCK_CLOEXEC) && defined(HAVE_FCNTL)
152 : (void) fcntl(s, F_SETFD, FD_CLOEXEC);
153 : #endif
154 :
155 81 : if (timeout > 0) {
156 0 : struct timeval tv = {
157 0 : .tv_sec = timeout / 1000,
158 0 : .tv_usec = timeout % 1000,
159 : };
160 0 : if (
161 0 : setsockopt(s, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) == SOCKET_ERROR
162 0 : || setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) == SOCKET_ERROR
163 : ) {
164 0 : closesocket(s);
165 0 : return mapi_printError(
166 : mid, __func__, MERROR,
167 0 : "could not set connect timeout: %s", strerror(errno));
168 : }
169 : }
170 :
171 : // Attempt to connect
172 :
173 81 : userver = (struct sockaddr_un) {
174 : .sun_family = AF_UNIX,
175 : };
176 81 : strcpy_len(userver.sun_path, sockname, sizeof(userver.sun_path));
177 :
178 81 : if (connect(s, (struct sockaddr *) &userver, sizeof(struct sockaddr_un)) == SOCKET_ERROR) {
179 13 : closesocket(s);
180 13 : return mapi_printError(
181 : mid, __func__, MERROR,
182 13 : "connect to Unix domain socket '%s' failed: %s", sockname, strerror(errno));
183 : }
184 :
185 : // Send an initial zero (not NUL) to let the server know we're not passing a file
186 : // descriptor.
187 :
188 68 : ssize_t n = send(s, "0", 1, 0);
189 68 : if (n < 1) {
190 : // used to be if n < 0 but this makes more sense
191 0 : closesocket(s);
192 0 : return mapi_printError(
193 : mid, __func__, MERROR,
194 0 : "could not send initial '0' on Unix domain socket: %s", strerror(errno));
195 : }
196 :
197 68 : return wrap_socket(mid, s);
198 : }
|