Branch data Line data Source code
1 : : // Copyright (c) 2020-2022 The Bitcoin Core developers
2 : : // Distributed under the MIT software license, see the accompanying
3 : : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 : :
5 : : #include <common/system.h>
6 : : #include <compat/compat.h>
7 : : #include <logging.h>
8 : : #include <tinyformat.h>
9 : : #include <util/sock.h>
10 : : #include <util/syserror.h>
11 : : #include <util/threadinterrupt.h>
12 : : #include <util/time.h>
13 : :
14 : : #include <memory>
15 : : #include <stdexcept>
16 : : #include <string>
17 : :
18 : : #ifdef USE_POLL
19 : : #include <poll.h>
20 : : #endif
21 : :
22 : 0 : static inline bool IOErrorIsPermanent(int err)
23 : : {
24 [ # # ][ # # ]: 0 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
[ # # ]
25 : : }
26 : :
27 : 0 : Sock::Sock(SOCKET s) : m_socket(s) {}
28 : :
29 : 0 : Sock::Sock(Sock&& other)
30 : 0 : {
31 : 0 : m_socket = other.m_socket;
32 : 0 : other.m_socket = INVALID_SOCKET;
33 : 0 : }
34 : :
35 [ # # ]: 0 : Sock::~Sock() { Close(); }
36 : :
37 : 0 : Sock& Sock::operator=(Sock&& other)
38 : : {
39 : 0 : Close();
40 : 0 : m_socket = other.m_socket;
41 : 0 : other.m_socket = INVALID_SOCKET;
42 : 0 : return *this;
43 : : }
44 : :
45 : 0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
46 : : {
47 : 0 : return send(m_socket, static_cast<const char*>(data), len, flags);
48 : : }
49 : :
50 : 0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
51 : : {
52 : 0 : return recv(m_socket, static_cast<char*>(buf), len, flags);
53 : : }
54 : :
55 : 0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
56 : : {
57 : 0 : return connect(m_socket, addr, addr_len);
58 : : }
59 : :
60 : 0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
61 : : {
62 : 0 : return bind(m_socket, addr, addr_len);
63 : : }
64 : :
65 : 0 : int Sock::Listen(int backlog) const
66 : : {
67 : 0 : return listen(m_socket, backlog);
68 : : }
69 : :
70 : 0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
71 : : {
72 : : #ifdef WIN32
73 : : static constexpr auto ERR = INVALID_SOCKET;
74 : : #else
75 : : static constexpr auto ERR = SOCKET_ERROR;
76 : : #endif
77 : :
78 : 0 : std::unique_ptr<Sock> sock;
79 : :
80 [ # # ]: 0 : const auto socket = accept(m_socket, addr, addr_len);
81 : 0 : if (socket != ERR) {
82 : : try {
83 [ # # ]: 0 : sock = std::make_unique<Sock>(socket);
84 [ # # ]: 0 : } catch (const std::exception&) {
85 : : #ifdef WIN32
86 : : closesocket(socket);
87 : : #else
88 [ # # ]: 0 : close(socket);
89 : : #endif
90 [ # # ][ # # ]: 0 : }
91 : 0 : }
92 : :
93 : 0 : return sock;
94 : 0 : }
95 : :
96 : 0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
97 : : {
98 : 0 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
99 : : }
100 : :
101 : 0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
102 : : {
103 : 0 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
104 : : }
105 : :
106 : 0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
107 : : {
108 : 0 : return getsockname(m_socket, name, name_len);
109 : : }
110 : :
111 : 0 : bool Sock::SetNonBlocking() const
112 : : {
113 : : #ifdef WIN32
114 : : u_long on{1};
115 : : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
116 : : return false;
117 : : }
118 : : #else
119 : 0 : const int flags{fcntl(m_socket, F_GETFL, 0)};
120 [ # # ]: 0 : if (flags == SOCKET_ERROR) {
121 : 0 : return false;
122 : : }
123 [ # # ]: 0 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
124 : 0 : return false;
125 : : }
126 : : #endif
127 : 0 : return true;
128 : 0 : }
129 : :
130 : 0 : bool Sock::IsSelectable() const
131 : : {
132 : : #if defined(USE_POLL) || defined(WIN32)
133 : 0 : return true;
134 : : #else
135 : : return m_socket < FD_SETSIZE;
136 : : #endif
137 : : }
138 : :
139 : 0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
140 : : {
141 : : // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
142 : : // `this` to be destroyed when the `shared_ptr` goes out of scope at the
143 : : // end of this function. Create it with a custom noop deleter.
144 : 0 : std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
145 : :
146 [ # # ][ # # ]: 0 : EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
[ # # ][ # # ]
147 : :
148 [ # # ][ # # ]: 0 : if (!WaitMany(timeout, events_per_sock)) {
149 : 0 : return false;
150 : : }
151 : :
152 [ # # ]: 0 : if (occurred != nullptr) {
153 : 0 : *occurred = events_per_sock.begin()->second.occurred;
154 : 0 : }
155 : :
156 : 0 : return true;
157 : 0 : }
158 : :
159 : 0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
160 : : {
161 : : #ifdef USE_POLL
162 : 0 : std::vector<pollfd> pfds;
163 [ # # ]: 0 : for (const auto& [sock, events] : events_per_sock) {
164 [ # # ]: 0 : pfds.emplace_back();
165 : 0 : auto& pfd = pfds.back();
166 : 0 : pfd.fd = sock->m_socket;
167 [ # # ]: 0 : if (events.requested & RECV) {
168 : 0 : pfd.events |= POLLIN;
169 : 0 : }
170 [ # # ]: 0 : if (events.requested & SEND) {
171 : 0 : pfd.events |= POLLOUT;
172 : 0 : }
173 : : }
174 : :
175 [ # # ][ # # ]: 0 : if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
[ # # ]
176 : 0 : return false;
177 : : }
178 : :
179 [ # # ]: 0 : assert(pfds.size() == events_per_sock.size());
180 : 0 : size_t i{0};
181 [ # # ]: 0 : for (auto& [sock, events] : events_per_sock) {
182 [ # # ]: 0 : assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
183 : 0 : events.occurred = 0;
184 [ # # ]: 0 : if (pfds[i].revents & POLLIN) {
185 : 0 : events.occurred |= RECV;
186 : 0 : }
187 [ # # ]: 0 : if (pfds[i].revents & POLLOUT) {
188 : 0 : events.occurred |= SEND;
189 : 0 : }
190 [ # # ]: 0 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
191 : 0 : events.occurred |= ERR;
192 : 0 : }
193 : 0 : ++i;
194 : : }
195 : :
196 : 0 : return true;
197 : : #else
198 : : fd_set recv;
199 : : fd_set send;
200 : : fd_set err;
201 : : FD_ZERO(&recv);
202 : : FD_ZERO(&send);
203 : : FD_ZERO(&err);
204 : : SOCKET socket_max{0};
205 : :
206 : : for (const auto& [sock, events] : events_per_sock) {
207 : : if (!sock->IsSelectable()) {
208 : : return false;
209 : : }
210 : : const auto& s = sock->m_socket;
211 : : if (events.requested & RECV) {
212 : : FD_SET(s, &recv);
213 : : }
214 : : if (events.requested & SEND) {
215 : : FD_SET(s, &send);
216 : : }
217 : : FD_SET(s, &err);
218 : : socket_max = std::max(socket_max, s);
219 : : }
220 : :
221 : : timeval tv = MillisToTimeval(timeout);
222 : :
223 : : if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
224 : : return false;
225 : : }
226 : :
227 : : for (auto& [sock, events] : events_per_sock) {
228 : : const auto& s = sock->m_socket;
229 : : events.occurred = 0;
230 : : if (FD_ISSET(s, &recv)) {
231 : : events.occurred |= RECV;
232 : : }
233 : : if (FD_ISSET(s, &send)) {
234 : : events.occurred |= SEND;
235 : : }
236 : : if (FD_ISSET(s, &err)) {
237 : : events.occurred |= ERR;
238 : : }
239 : : }
240 : :
241 : : return true;
242 : : #endif /* USE_POLL */
243 : 0 : }
244 : :
245 : 0 : void Sock::SendComplete(Span<const unsigned char> data,
246 : : std::chrono::milliseconds timeout,
247 : : CThreadInterrupt& interrupt) const
248 : : {
249 : 0 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
250 : 0 : size_t sent{0};
251 : :
252 : 0 : for (;;) {
253 : 0 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
254 : :
255 [ # # ]: 0 : if (ret > 0) {
256 : 0 : sent += static_cast<size_t>(ret);
257 [ # # ]: 0 : if (sent == data.size()) {
258 : 0 : break;
259 : : }
260 : 0 : } else {
261 : 0 : const int err{WSAGetLastError()};
262 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
263 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
[ # # ][ # # ]
264 : : }
265 : : }
266 : :
267 : 0 : const auto now = GetTime<std::chrono::milliseconds>();
268 : :
269 [ # # ]: 0 : if (now >= deadline) {
270 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf(
[ # # ][ # # ]
271 : 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
272 : : }
273 : :
274 [ # # ]: 0 : if (interrupt) {
275 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf(
[ # # ][ # # ]
276 : 0 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
277 : : }
278 : :
279 : : // Wait for a short while (or the socket to become ready for sending) before retrying
280 : : // if nothing was sent.
281 : 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
282 : 0 : (void)Wait(wait_time, SEND);
283 : : }
284 : 0 : }
285 : :
286 : 0 : void Sock::SendComplete(Span<const char> data,
287 : : std::chrono::milliseconds timeout,
288 : : CThreadInterrupt& interrupt) const
289 : : {
290 : 0 : SendComplete(MakeUCharSpan(data), timeout, interrupt);
291 : 0 : }
292 : :
293 : 0 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
294 : : std::chrono::milliseconds timeout,
295 : : CThreadInterrupt& interrupt,
296 : : size_t max_data) const
297 : : {
298 : 0 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
299 : 0 : std::string data;
300 : 0 : bool terminator_found{false};
301 : :
302 : : // We must not consume any bytes past the terminator from the socket.
303 : : // One option is to read one byte at a time and check if we have read a terminator.
304 : : // However that is very slow. Instead, we peek at what is in the socket and only read
305 : : // as many bytes as possible without crossing the terminator.
306 : : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
307 : : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
308 : : // at a time is about 50 times slower.
309 : :
310 : 0 : for (;;) {
311 [ # # ]: 0 : if (data.size() >= max_data) {
312 [ # # ][ # # ]: 0 : throw std::runtime_error(
313 [ # # ]: 0 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
314 : : }
315 : :
316 : : char buf[512];
317 : :
318 [ # # ][ # # ]: 0 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
319 : :
320 [ # # # ]: 0 : switch (peek_ret) {
321 : : case -1: {
322 : 0 : const int err{WSAGetLastError()};
323 [ # # ][ # # ]: 0 : if (IOErrorIsPermanent(err)) {
324 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
[ # # ][ # # ]
325 : : }
326 : 0 : break;
327 : : }
328 : : case 0:
329 [ # # ]: 0 : throw std::runtime_error("Connection unexpectedly closed by peer");
330 : : default:
331 : 0 : auto end = buf + peek_ret;
332 [ # # ]: 0 : auto terminator_pos = std::find(buf, end, terminator);
333 : 0 : terminator_found = terminator_pos != end;
334 : :
335 [ # # ]: 0 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
336 : 0 : static_cast<size_t>(peek_ret)};
337 : :
338 [ # # ]: 0 : const ssize_t read_ret{Recv(buf, try_len, 0)};
339 : :
340 [ # # ]: 0 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
341 [ # # ][ # # ]: 0 : throw std::runtime_error(
342 [ # # ]: 0 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
343 : : "peek claimed %u bytes are available",
344 : : read_ret, try_len, peek_ret));
345 : : }
346 : :
347 : : // Don't include the terminator in the output.
348 [ # # ]: 0 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
349 : :
350 [ # # ]: 0 : data.append(buf, buf + append_len);
351 : :
352 [ # # ]: 0 : if (terminator_found) {
353 : 0 : return data;
354 : : }
355 : 0 : }
356 : :
357 [ # # ]: 0 : const auto now = GetTime<std::chrono::milliseconds>();
358 : :
359 [ # # ][ # # ]: 0 : if (now >= deadline) {
360 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf(
[ # # ][ # # ]
361 : 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
362 : : }
363 : :
364 [ # # ][ # # ]: 0 : if (interrupt) {
365 [ # # ][ # # ]: 0 : throw std::runtime_error(strprintf(
[ # # ][ # # ]
366 : : "Receive interrupted (received %u bytes without terminator before that)",
367 : 0 : data.size()));
368 : : }
369 : :
370 : : // Wait for a short while (or the socket to become ready for reading) before retrying.
371 [ # # ][ # # ]: 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
[ # # ]
372 [ # # ]: 0 : (void)Wait(wait_time, RECV);
373 : : }
374 [ # # ]: 0 : }
375 : :
376 : 0 : bool Sock::IsConnected(std::string& errmsg) const
377 : : {
378 [ # # ]: 0 : if (m_socket == INVALID_SOCKET) {
379 : 0 : errmsg = "not connected";
380 : 0 : return false;
381 : : }
382 : :
383 : : char c;
384 [ # # # ]: 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
385 : : case -1: {
386 : 0 : const int err = WSAGetLastError();
387 [ # # ]: 0 : if (IOErrorIsPermanent(err)) {
388 : 0 : errmsg = NetworkErrorString(err);
389 : 0 : return false;
390 : : }
391 : 0 : return true;
392 : : }
393 : : case 0:
394 : 0 : errmsg = "closed";
395 : 0 : return false;
396 : : default:
397 : 0 : return true;
398 : : }
399 : 0 : }
400 : :
401 : 0 : void Sock::Close()
402 : : {
403 [ # # ]: 0 : if (m_socket == INVALID_SOCKET) {
404 : 0 : return;
405 : : }
406 : : #ifdef WIN32
407 : : int ret = closesocket(m_socket);
408 : : #else
409 : 0 : int ret = close(m_socket);
410 : : #endif
411 [ # # ]: 0 : if (ret) {
412 [ # # ][ # # ]: 0 : LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
[ # # ][ # # ]
413 : 0 : }
414 : 0 : m_socket = INVALID_SOCKET;
415 : 0 : }
416 : :
417 : 0 : bool Sock::operator==(SOCKET s) const
418 : : {
419 : 0 : return m_socket == s;
420 : : };
421 : :
422 : 0 : std::string NetworkErrorString(int err)
423 : : {
424 : : #if defined(WIN32)
425 : : return Win32ErrorString(err);
426 : : #else
427 : : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
428 : 0 : return SysErrorString(err);
429 : : #endif
430 : : }
|