Line data Source code
1 : // Copyright (c) 2020-2022 The Bitcoin Core developers
2 : // Distributed under the MIT software license, see the accompanying
3 : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
4 :
5 : #include <common/system.h>
6 : #include <compat/compat.h>
7 : #include <logging.h>
8 : #include <tinyformat.h>
9 : #include <util/sock.h>
10 : #include <util/syserror.h>
11 : #include <util/threadinterrupt.h>
12 : #include <util/time.h>
13 :
14 : #include <memory>
15 : #include <stdexcept>
16 : #include <string>
17 :
18 : #ifdef USE_POLL
19 : #include <poll.h>
20 : #endif
21 :
22 0 : static inline bool IOErrorIsPermanent(int err)
23 : {
24 0 : return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
25 : }
26 :
27 0 : Sock::Sock() : m_socket(INVALID_SOCKET) {}
28 :
29 0 : Sock::Sock(SOCKET s) : m_socket(s) {}
30 :
31 0 : Sock::Sock(Sock&& other)
32 0 : {
33 0 : m_socket = other.m_socket;
34 0 : other.m_socket = INVALID_SOCKET;
35 0 : }
36 :
37 0 : Sock::~Sock() { Close(); }
38 :
39 0 : Sock& Sock::operator=(Sock&& other)
40 : {
41 0 : Close();
42 0 : m_socket = other.m_socket;
43 0 : other.m_socket = INVALID_SOCKET;
44 0 : return *this;
45 : }
46 :
47 0 : SOCKET Sock::Get() const { return m_socket; }
48 :
49 0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
50 : {
51 0 : return send(m_socket, static_cast<const char*>(data), len, flags);
52 : }
53 :
54 0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
55 : {
56 0 : return recv(m_socket, static_cast<char*>(buf), len, flags);
57 : }
58 :
59 0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
60 : {
61 0 : return connect(m_socket, addr, addr_len);
62 : }
63 :
64 0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
65 : {
66 0 : return bind(m_socket, addr, addr_len);
67 : }
68 :
69 0 : int Sock::Listen(int backlog) const
70 : {
71 0 : return listen(m_socket, backlog);
72 : }
73 :
74 2 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
75 : {
76 : #ifdef WIN32
77 : static constexpr auto ERR = INVALID_SOCKET;
78 : #else
79 : static constexpr auto ERR = SOCKET_ERROR;
80 : #endif
81 :
82 0 : std::unique_ptr<Sock> sock;
83 :
84 0 : const auto socket = accept(m_socket, addr, addr_len);
85 0 : if (socket != ERR) {
86 : try {
87 0 : sock = std::make_unique<Sock>(socket);
88 0 : } catch (const std::exception&) {
89 : #ifdef WIN32
90 : closesocket(socket);
91 : #else
92 0 : close(socket);
93 : #endif
94 0 : }
95 0 : }
96 :
97 0 : return sock;
98 0 : }
99 :
100 0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
101 : {
102 0 : return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
103 : }
104 :
105 0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
106 : {
107 0 : return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
108 : }
109 :
110 0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
111 : {
112 0 : return getsockname(m_socket, name, name_len);
113 : }
114 :
115 0 : bool Sock::SetNonBlocking() const
116 : {
117 : #ifdef WIN32
118 : u_long on{1};
119 : if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
120 : return false;
121 : }
122 : #else
123 0 : const int flags{fcntl(m_socket, F_GETFL, 0)};
124 0 : if (flags == SOCKET_ERROR) {
125 0 : return false;
126 : }
127 0 : if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
128 0 : return false;
129 : }
130 : #endif
131 0 : return true;
132 0 : }
133 :
134 0 : bool Sock::IsSelectable() const
135 : {
136 : #if defined(USE_POLL) || defined(WIN32)
137 0 : return true;
138 : #else
139 : return m_socket < FD_SETSIZE;
140 : #endif
141 : }
142 :
143 0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
144 : {
145 : // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
146 : // `this` to be destroyed when the `shared_ptr` goes out of scope at the
147 : // end of this function. Create it with a custom noop deleter.
148 0 : std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
149 :
150 0 : EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
151 :
152 0 : if (!WaitMany(timeout, events_per_sock)) {
153 0 : return false;
154 : }
155 :
156 0 : if (occurred != nullptr) {
157 0 : *occurred = events_per_sock.begin()->second.occurred;
158 0 : }
159 :
160 0 : return true;
161 0 : }
162 :
163 0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
164 : {
165 : #ifdef USE_POLL
166 0 : std::vector<pollfd> pfds;
167 0 : for (const auto& [sock, events] : events_per_sock) {
168 0 : pfds.emplace_back();
169 0 : auto& pfd = pfds.back();
170 0 : pfd.fd = sock->m_socket;
171 0 : if (events.requested & RECV) {
172 0 : pfd.events |= POLLIN;
173 0 : }
174 0 : if (events.requested & SEND) {
175 0 : pfd.events |= POLLOUT;
176 0 : }
177 : }
178 :
179 0 : if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
180 0 : return false;
181 : }
182 :
183 0 : assert(pfds.size() == events_per_sock.size());
184 0 : size_t i{0};
185 0 : for (auto& [sock, events] : events_per_sock) {
186 0 : assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
187 0 : events.occurred = 0;
188 0 : if (pfds[i].revents & POLLIN) {
189 0 : events.occurred |= RECV;
190 0 : }
191 0 : if (pfds[i].revents & POLLOUT) {
192 0 : events.occurred |= SEND;
193 0 : }
194 0 : if (pfds[i].revents & (POLLERR | POLLHUP)) {
195 0 : events.occurred |= ERR;
196 0 : }
197 0 : ++i;
198 : }
199 :
200 0 : return true;
201 : #else
202 : fd_set recv;
203 : fd_set send;
204 : fd_set err;
205 : FD_ZERO(&recv);
206 : FD_ZERO(&send);
207 : FD_ZERO(&err);
208 : SOCKET socket_max{0};
209 :
210 : for (const auto& [sock, events] : events_per_sock) {
211 : if (!sock->IsSelectable()) {
212 : return false;
213 : }
214 : const auto& s = sock->m_socket;
215 : if (events.requested & RECV) {
216 : FD_SET(s, &recv);
217 : }
218 : if (events.requested & SEND) {
219 : FD_SET(s, &send);
220 : }
221 : FD_SET(s, &err);
222 : socket_max = std::max(socket_max, s);
223 : }
224 :
225 : timeval tv = MillisToTimeval(timeout);
226 :
227 : if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
228 : return false;
229 : }
230 :
231 : for (auto& [sock, events] : events_per_sock) {
232 : const auto& s = sock->m_socket;
233 : events.occurred = 0;
234 : if (FD_ISSET(s, &recv)) {
235 : events.occurred |= RECV;
236 : }
237 : if (FD_ISSET(s, &send)) {
238 : events.occurred |= SEND;
239 : }
240 : if (FD_ISSET(s, &err)) {
241 : events.occurred |= ERR;
242 : }
243 : }
244 :
245 : return true;
246 : #endif /* USE_POLL */
247 0 : }
248 :
249 0 : void Sock::SendComplete(const std::string& data,
250 : std::chrono::milliseconds timeout,
251 : CThreadInterrupt& interrupt) const
252 : {
253 0 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
254 0 : size_t sent{0};
255 :
256 0 : for (;;) {
257 0 : const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
258 :
259 0 : if (ret > 0) {
260 0 : sent += static_cast<size_t>(ret);
261 0 : if (sent == data.size()) {
262 0 : break;
263 : }
264 0 : } else {
265 0 : const int err{WSAGetLastError()};
266 0 : if (IOErrorIsPermanent(err)) {
267 0 : throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
268 : }
269 : }
270 :
271 0 : const auto now = GetTime<std::chrono::milliseconds>();
272 :
273 0 : if (now >= deadline) {
274 0 : throw std::runtime_error(strprintf(
275 0 : "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
276 : }
277 :
278 0 : if (interrupt) {
279 0 : throw std::runtime_error(strprintf(
280 0 : "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
281 : }
282 :
283 : // Wait for a short while (or the socket to become ready for sending) before retrying
284 : // if nothing was sent.
285 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
286 0 : (void)Wait(wait_time, SEND);
287 : }
288 0 : }
289 :
290 0 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
291 : std::chrono::milliseconds timeout,
292 : CThreadInterrupt& interrupt,
293 : size_t max_data) const
294 : {
295 0 : const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
296 0 : std::string data;
297 0 : bool terminator_found{false};
298 :
299 : // We must not consume any bytes past the terminator from the socket.
300 : // One option is to read one byte at a time and check if we have read a terminator.
301 : // However that is very slow. Instead, we peek at what is in the socket and only read
302 : // as many bytes as possible without crossing the terminator.
303 : // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
304 : // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
305 : // at a time is about 50 times slower.
306 :
307 0 : for (;;) {
308 0 : if (data.size() >= max_data) {
309 0 : throw std::runtime_error(
310 0 : strprintf("Received too many bytes without a terminator (%u)", data.size()));
311 : }
312 :
313 : char buf[512];
314 :
315 0 : const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
316 :
317 0 : switch (peek_ret) {
318 : case -1: {
319 0 : const int err{WSAGetLastError()};
320 0 : if (IOErrorIsPermanent(err)) {
321 0 : throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
322 : }
323 0 : break;
324 : }
325 : case 0:
326 0 : throw std::runtime_error("Connection unexpectedly closed by peer");
327 : default:
328 0 : auto end = buf + peek_ret;
329 0 : auto terminator_pos = std::find(buf, end, terminator);
330 0 : terminator_found = terminator_pos != end;
331 :
332 0 : const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
333 0 : static_cast<size_t>(peek_ret)};
334 :
335 0 : const ssize_t read_ret{Recv(buf, try_len, 0)};
336 :
337 0 : if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
338 0 : throw std::runtime_error(
339 0 : strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
340 : "peek claimed %u bytes are available",
341 : read_ret, try_len, peek_ret));
342 : }
343 :
344 : // Don't include the terminator in the output.
345 0 : const size_t append_len{terminator_found ? try_len - 1 : try_len};
346 :
347 0 : data.append(buf, buf + append_len);
348 :
349 0 : if (terminator_found) {
350 0 : return data;
351 : }
352 0 : }
353 :
354 0 : const auto now = GetTime<std::chrono::milliseconds>();
355 :
356 0 : if (now >= deadline) {
357 0 : throw std::runtime_error(strprintf(
358 0 : "Receive timeout (received %u bytes without terminator before that)", data.size()));
359 : }
360 :
361 0 : if (interrupt) {
362 0 : throw std::runtime_error(strprintf(
363 : "Receive interrupted (received %u bytes without terminator before that)",
364 0 : data.size()));
365 : }
366 :
367 : // Wait for a short while (or the socket to become ready for reading) before retrying.
368 0 : const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
369 0 : (void)Wait(wait_time, RECV);
370 : }
371 0 : }
372 :
373 0 : bool Sock::IsConnected(std::string& errmsg) const
374 : {
375 0 : if (m_socket == INVALID_SOCKET) {
376 0 : errmsg = "not connected";
377 0 : return false;
378 : }
379 :
380 : char c;
381 0 : switch (Recv(&c, sizeof(c), MSG_PEEK)) {
382 : case -1: {
383 0 : const int err = WSAGetLastError();
384 0 : if (IOErrorIsPermanent(err)) {
385 0 : errmsg = NetworkErrorString(err);
386 0 : return false;
387 : }
388 0 : return true;
389 : }
390 : case 0:
391 0 : errmsg = "closed";
392 0 : return false;
393 : default:
394 0 : return true;
395 : }
396 0 : }
397 :
398 0 : void Sock::Close()
399 : {
400 0 : if (m_socket == INVALID_SOCKET) {
401 0 : return;
402 : }
403 : #ifdef WIN32
404 : int ret = closesocket(m_socket);
405 : #else
406 0 : int ret = close(m_socket);
407 : #endif
408 0 : if (ret) {
409 0 : LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
410 0 : }
411 0 : m_socket = INVALID_SOCKET;
412 0 : }
413 :
414 0 : std::string NetworkErrorString(int err)
415 : {
416 : #if defined(WIN32)
417 : return Win32ErrorString(err);
418 : #else
419 : // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
420 0 : return SysErrorString(err);
421 : #endif
422 : }
|