LCOV - code coverage report
Current view: top level - src/util - sock.cpp (source / functions) Hit Total Coverage
Test: fuzz_coverage.info Lines: 1 181 0.6 %
Date: 2023-09-26 12:08:55 Functions: 1 28 3.6 %

          Line data    Source code
       1             : // Copyright (c) 2020-2022 The Bitcoin Core developers
       2             : // Distributed under the MIT software license, see the accompanying
       3             : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4             : 
       5             : #include <common/system.h>
       6             : #include <compat/compat.h>
       7             : #include <logging.h>
       8             : #include <tinyformat.h>
       9             : #include <util/sock.h>
      10             : #include <util/syserror.h>
      11             : #include <util/threadinterrupt.h>
      12             : #include <util/time.h>
      13             : 
      14             : #include <memory>
      15             : #include <stdexcept>
      16             : #include <string>
      17             : 
      18             : #ifdef USE_POLL
      19             : #include <poll.h>
      20             : #endif
      21             : 
      22           0 : static inline bool IOErrorIsPermanent(int err)
      23             : {
      24           0 :     return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
      25             : }
      26             : 
      27           0 : Sock::Sock() : m_socket(INVALID_SOCKET) {}
      28             : 
      29           0 : Sock::Sock(SOCKET s) : m_socket(s) {}
      30             : 
      31           0 : Sock::Sock(Sock&& other)
      32           0 : {
      33           0 :     m_socket = other.m_socket;
      34           0 :     other.m_socket = INVALID_SOCKET;
      35           0 : }
      36             : 
      37           0 : Sock::~Sock() { Close(); }
      38             : 
      39           0 : Sock& Sock::operator=(Sock&& other)
      40             : {
      41           0 :     Close();
      42           0 :     m_socket = other.m_socket;
      43           0 :     other.m_socket = INVALID_SOCKET;
      44           0 :     return *this;
      45             : }
      46             : 
      47           0 : SOCKET Sock::Get() const { return m_socket; }
      48             : 
      49           0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
      50             : {
      51           0 :     return send(m_socket, static_cast<const char*>(data), len, flags);
      52             : }
      53             : 
      54           0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
      55             : {
      56           0 :     return recv(m_socket, static_cast<char*>(buf), len, flags);
      57             : }
      58             : 
      59           0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
      60             : {
      61           0 :     return connect(m_socket, addr, addr_len);
      62             : }
      63             : 
      64           0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
      65             : {
      66           0 :     return bind(m_socket, addr, addr_len);
      67             : }
      68             : 
      69           0 : int Sock::Listen(int backlog) const
      70             : {
      71           0 :     return listen(m_socket, backlog);
      72             : }
      73             : 
      74           2 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
      75             : {
      76             : #ifdef WIN32
      77             :     static constexpr auto ERR = INVALID_SOCKET;
      78             : #else
      79             :     static constexpr auto ERR = SOCKET_ERROR;
      80             : #endif
      81             : 
      82           0 :     std::unique_ptr<Sock> sock;
      83             : 
      84           0 :     const auto socket = accept(m_socket, addr, addr_len);
      85           0 :     if (socket != ERR) {
      86             :         try {
      87           0 :             sock = std::make_unique<Sock>(socket);
      88           0 :         } catch (const std::exception&) {
      89             : #ifdef WIN32
      90             :             closesocket(socket);
      91             : #else
      92           0 :             close(socket);
      93             : #endif
      94           0 :         }
      95           0 :     }
      96             : 
      97           0 :     return sock;
      98           0 : }
      99             : 
     100           0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
     101             : {
     102           0 :     return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
     103             : }
     104             : 
     105           0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
     106             : {
     107           0 :     return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
     108             : }
     109             : 
     110           0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
     111             : {
     112           0 :     return getsockname(m_socket, name, name_len);
     113             : }
     114             : 
     115           0 : bool Sock::SetNonBlocking() const
     116             : {
     117             : #ifdef WIN32
     118             :     u_long on{1};
     119             :     if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
     120             :         return false;
     121             :     }
     122             : #else
     123           0 :     const int flags{fcntl(m_socket, F_GETFL, 0)};
     124           0 :     if (flags == SOCKET_ERROR) {
     125           0 :         return false;
     126             :     }
     127           0 :     if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
     128           0 :         return false;
     129             :     }
     130             : #endif
     131           0 :     return true;
     132           0 : }
     133             : 
     134           0 : bool Sock::IsSelectable() const
     135             : {
     136             : #if defined(USE_POLL) || defined(WIN32)
     137           0 :     return true;
     138             : #else
     139             :     return m_socket < FD_SETSIZE;
     140             : #endif
     141             : }
     142             : 
     143           0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
     144             : {
     145             :     // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
     146             :     // `this` to be destroyed when the `shared_ptr` goes out of scope at the
     147             :     // end of this function. Create it with a custom noop deleter.
     148           0 :     std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
     149             : 
     150           0 :     EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
     151             : 
     152           0 :     if (!WaitMany(timeout, events_per_sock)) {
     153           0 :         return false;
     154             :     }
     155             : 
     156           0 :     if (occurred != nullptr) {
     157           0 :         *occurred = events_per_sock.begin()->second.occurred;
     158           0 :     }
     159             : 
     160           0 :     return true;
     161           0 : }
     162             : 
     163           0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
     164             : {
     165             : #ifdef USE_POLL
     166           0 :     std::vector<pollfd> pfds;
     167           0 :     for (const auto& [sock, events] : events_per_sock) {
     168           0 :         pfds.emplace_back();
     169           0 :         auto& pfd = pfds.back();
     170           0 :         pfd.fd = sock->m_socket;
     171           0 :         if (events.requested & RECV) {
     172           0 :             pfd.events |= POLLIN;
     173           0 :         }
     174           0 :         if (events.requested & SEND) {
     175           0 :             pfd.events |= POLLOUT;
     176           0 :         }
     177             :     }
     178             : 
     179           0 :     if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
     180           0 :         return false;
     181             :     }
     182             : 
     183           0 :     assert(pfds.size() == events_per_sock.size());
     184           0 :     size_t i{0};
     185           0 :     for (auto& [sock, events] : events_per_sock) {
     186           0 :         assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
     187           0 :         events.occurred = 0;
     188           0 :         if (pfds[i].revents & POLLIN) {
     189           0 :             events.occurred |= RECV;
     190           0 :         }
     191           0 :         if (pfds[i].revents & POLLOUT) {
     192           0 :             events.occurred |= SEND;
     193           0 :         }
     194           0 :         if (pfds[i].revents & (POLLERR | POLLHUP)) {
     195           0 :             events.occurred |= ERR;
     196           0 :         }
     197           0 :         ++i;
     198             :     }
     199             : 
     200           0 :     return true;
     201             : #else
     202             :     fd_set recv;
     203             :     fd_set send;
     204             :     fd_set err;
     205             :     FD_ZERO(&recv);
     206             :     FD_ZERO(&send);
     207             :     FD_ZERO(&err);
     208             :     SOCKET socket_max{0};
     209             : 
     210             :     for (const auto& [sock, events] : events_per_sock) {
     211             :         if (!sock->IsSelectable()) {
     212             :             return false;
     213             :         }
     214             :         const auto& s = sock->m_socket;
     215             :         if (events.requested & RECV) {
     216             :             FD_SET(s, &recv);
     217             :         }
     218             :         if (events.requested & SEND) {
     219             :             FD_SET(s, &send);
     220             :         }
     221             :         FD_SET(s, &err);
     222             :         socket_max = std::max(socket_max, s);
     223             :     }
     224             : 
     225             :     timeval tv = MillisToTimeval(timeout);
     226             : 
     227             :     if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
     228             :         return false;
     229             :     }
     230             : 
     231             :     for (auto& [sock, events] : events_per_sock) {
     232             :         const auto& s = sock->m_socket;
     233             :         events.occurred = 0;
     234             :         if (FD_ISSET(s, &recv)) {
     235             :             events.occurred |= RECV;
     236             :         }
     237             :         if (FD_ISSET(s, &send)) {
     238             :             events.occurred |= SEND;
     239             :         }
     240             :         if (FD_ISSET(s, &err)) {
     241             :             events.occurred |= ERR;
     242             :         }
     243             :     }
     244             : 
     245             :     return true;
     246             : #endif /* USE_POLL */
     247           0 : }
     248             : 
     249           0 : void Sock::SendComplete(const std::string& data,
     250             :                         std::chrono::milliseconds timeout,
     251             :                         CThreadInterrupt& interrupt) const
     252             : {
     253           0 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     254           0 :     size_t sent{0};
     255             : 
     256           0 :     for (;;) {
     257           0 :         const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
     258             : 
     259           0 :         if (ret > 0) {
     260           0 :             sent += static_cast<size_t>(ret);
     261           0 :             if (sent == data.size()) {
     262           0 :                 break;
     263             :             }
     264           0 :         } else {
     265           0 :             const int err{WSAGetLastError()};
     266           0 :             if (IOErrorIsPermanent(err)) {
     267           0 :                 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
     268             :             }
     269             :         }
     270             : 
     271           0 :         const auto now = GetTime<std::chrono::milliseconds>();
     272             : 
     273           0 :         if (now >= deadline) {
     274           0 :             throw std::runtime_error(strprintf(
     275           0 :                 "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
     276             :         }
     277             : 
     278           0 :         if (interrupt) {
     279           0 :             throw std::runtime_error(strprintf(
     280           0 :                 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
     281             :         }
     282             : 
     283             :         // Wait for a short while (or the socket to become ready for sending) before retrying
     284             :         // if nothing was sent.
     285           0 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
     286           0 :         (void)Wait(wait_time, SEND);
     287             :     }
     288           0 : }
     289             : 
     290           0 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
     291             :                                       std::chrono::milliseconds timeout,
     292             :                                       CThreadInterrupt& interrupt,
     293             :                                       size_t max_data) const
     294             : {
     295           0 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     296           0 :     std::string data;
     297           0 :     bool terminator_found{false};
     298             : 
     299             :     // We must not consume any bytes past the terminator from the socket.
     300             :     // One option is to read one byte at a time and check if we have read a terminator.
     301             :     // However that is very slow. Instead, we peek at what is in the socket and only read
     302             :     // as many bytes as possible without crossing the terminator.
     303             :     // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
     304             :     // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
     305             :     // at a time is about 50 times slower.
     306             : 
     307           0 :     for (;;) {
     308           0 :         if (data.size() >= max_data) {
     309           0 :             throw std::runtime_error(
     310           0 :                 strprintf("Received too many bytes without a terminator (%u)", data.size()));
     311             :         }
     312             : 
     313             :         char buf[512];
     314             : 
     315           0 :         const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
     316             : 
     317           0 :         switch (peek_ret) {
     318             :         case -1: {
     319           0 :             const int err{WSAGetLastError()};
     320           0 :             if (IOErrorIsPermanent(err)) {
     321           0 :                 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
     322             :             }
     323           0 :             break;
     324             :         }
     325             :         case 0:
     326           0 :             throw std::runtime_error("Connection unexpectedly closed by peer");
     327             :         default:
     328           0 :             auto end = buf + peek_ret;
     329           0 :             auto terminator_pos = std::find(buf, end, terminator);
     330           0 :             terminator_found = terminator_pos != end;
     331             : 
     332           0 :             const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
     333           0 :                                                     static_cast<size_t>(peek_ret)};
     334             : 
     335           0 :             const ssize_t read_ret{Recv(buf, try_len, 0)};
     336             : 
     337           0 :             if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
     338           0 :                 throw std::runtime_error(
     339           0 :                     strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
     340             :                               "peek claimed %u bytes are available",
     341             :                               read_ret, try_len, peek_ret));
     342             :             }
     343             : 
     344             :             // Don't include the terminator in the output.
     345           0 :             const size_t append_len{terminator_found ? try_len - 1 : try_len};
     346             : 
     347           0 :             data.append(buf, buf + append_len);
     348             : 
     349           0 :             if (terminator_found) {
     350           0 :                 return data;
     351             :             }
     352           0 :         }
     353             : 
     354           0 :         const auto now = GetTime<std::chrono::milliseconds>();
     355             : 
     356           0 :         if (now >= deadline) {
     357           0 :             throw std::runtime_error(strprintf(
     358           0 :                 "Receive timeout (received %u bytes without terminator before that)", data.size()));
     359             :         }
     360             : 
     361           0 :         if (interrupt) {
     362           0 :             throw std::runtime_error(strprintf(
     363             :                 "Receive interrupted (received %u bytes without terminator before that)",
     364           0 :                 data.size()));
     365             :         }
     366             : 
     367             :         // Wait for a short while (or the socket to become ready for reading) before retrying.
     368           0 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
     369           0 :         (void)Wait(wait_time, RECV);
     370             :     }
     371           0 : }
     372             : 
     373           0 : bool Sock::IsConnected(std::string& errmsg) const
     374             : {
     375           0 :     if (m_socket == INVALID_SOCKET) {
     376           0 :         errmsg = "not connected";
     377           0 :         return false;
     378             :     }
     379             : 
     380             :     char c;
     381           0 :     switch (Recv(&c, sizeof(c), MSG_PEEK)) {
     382             :     case -1: {
     383           0 :         const int err = WSAGetLastError();
     384           0 :         if (IOErrorIsPermanent(err)) {
     385           0 :             errmsg = NetworkErrorString(err);
     386           0 :             return false;
     387             :         }
     388           0 :         return true;
     389             :     }
     390             :     case 0:
     391           0 :         errmsg = "closed";
     392           0 :         return false;
     393             :     default:
     394           0 :         return true;
     395             :     }
     396           0 : }
     397             : 
     398           0 : void Sock::Close()
     399             : {
     400           0 :     if (m_socket == INVALID_SOCKET) {
     401           0 :         return;
     402             :     }
     403             : #ifdef WIN32
     404             :     int ret = closesocket(m_socket);
     405             : #else
     406           0 :     int ret = close(m_socket);
     407             : #endif
     408           0 :     if (ret) {
     409           0 :         LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
     410           0 :     }
     411           0 :     m_socket = INVALID_SOCKET;
     412           0 : }
     413             : 
     414           0 : std::string NetworkErrorString(int err)
     415             : {
     416             : #if defined(WIN32)
     417             :     return Win32ErrorString(err);
     418             : #else
     419             :     // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
     420           0 :     return SysErrorString(err);
     421             : #endif
     422             : }

Generated by: LCOV version 1.14