LCOV - code coverage report
Current view: top level - src/util - sock.cpp (source / functions) Hit Total Coverage
Test: fuzz_coverage.info Lines: 9 182 4.9 %
Date: 2023-10-05 12:38:51 Functions: 5 27 18.5 %
Branches: 2 204 1.0 %

           Branch data     Line data    Source code
       1                 :            : // Copyright (c) 2020-2022 The Bitcoin Core developers
       2                 :            : // Distributed under the MIT software license, see the accompanying
       3                 :            : // file COPYING or http://www.opensource.org/licenses/mit-license.php.
       4                 :            : 
       5                 :            : #include <common/system.h>
       6                 :            : #include <compat/compat.h>
       7                 :            : #include <logging.h>
       8                 :            : #include <tinyformat.h>
       9                 :            : #include <util/sock.h>
      10                 :            : #include <util/syserror.h>
      11                 :            : #include <util/threadinterrupt.h>
      12                 :            : #include <util/time.h>
      13                 :            : 
      14                 :            : #include <memory>
      15                 :            : #include <stdexcept>
      16                 :            : #include <string>
      17                 :            : 
      18                 :            : #ifdef USE_POLL
      19                 :            : #include <poll.h>
      20                 :            : #endif
      21                 :            : 
      22                 :          0 : static inline bool IOErrorIsPermanent(int err)
      23                 :            : {
      24   [ #  #  #  #  :          0 :     return err != WSAEAGAIN && err != WSAEINTR && err != WSAEWOULDBLOCK && err != WSAEINPROGRESS;
                   #  # ]
      25                 :            : }
      26                 :            : 
      27                 :      22407 : Sock::Sock(SOCKET s) : m_socket(s) {}
      28                 :            : 
      29                 :          0 : Sock::Sock(Sock&& other)
      30                 :          0 : {
      31                 :          0 :     m_socket = other.m_socket;
      32                 :          0 :     other.m_socket = INVALID_SOCKET;
      33                 :          0 : }
      34                 :            : 
      35         [ +  - ]:      22407 : Sock::~Sock() { Close(); }
      36                 :            : 
      37                 :          0 : Sock& Sock::operator=(Sock&& other)
      38                 :            : {
      39                 :          0 :     Close();
      40                 :          0 :     m_socket = other.m_socket;
      41                 :          0 :     other.m_socket = INVALID_SOCKET;
      42                 :          0 :     return *this;
      43                 :            : }
      44                 :            : 
      45                 :          0 : ssize_t Sock::Send(const void* data, size_t len, int flags) const
      46                 :            : {
      47                 :          0 :     return send(m_socket, static_cast<const char*>(data), len, flags);
      48                 :            : }
      49                 :            : 
      50                 :          0 : ssize_t Sock::Recv(void* buf, size_t len, int flags) const
      51                 :            : {
      52                 :          0 :     return recv(m_socket, static_cast<char*>(buf), len, flags);
      53                 :            : }
      54                 :            : 
      55                 :          0 : int Sock::Connect(const sockaddr* addr, socklen_t addr_len) const
      56                 :            : {
      57                 :          0 :     return connect(m_socket, addr, addr_len);
      58                 :            : }
      59                 :            : 
      60                 :          0 : int Sock::Bind(const sockaddr* addr, socklen_t addr_len) const
      61                 :            : {
      62                 :          0 :     return bind(m_socket, addr, addr_len);
      63                 :            : }
      64                 :            : 
      65                 :          0 : int Sock::Listen(int backlog) const
      66                 :            : {
      67                 :          0 :     return listen(m_socket, backlog);
      68                 :            : }
      69                 :            : 
      70                 :          0 : std::unique_ptr<Sock> Sock::Accept(sockaddr* addr, socklen_t* addr_len) const
      71                 :            : {
      72                 :            : #ifdef WIN32
      73                 :            :     static constexpr auto ERR = INVALID_SOCKET;
      74                 :          2 : #else
      75                 :            :     static constexpr auto ERR = SOCKET_ERROR;
      76                 :            : #endif
      77                 :            : 
      78                 :          0 :     std::unique_ptr<Sock> sock;
      79                 :            : 
      80         [ #  # ]:          0 :     const auto socket = accept(m_socket, addr, addr_len);
      81                 :          0 :     if (socket != ERR) {
      82                 :            :         try {
      83         [ #  # ]:          0 :             sock = std::make_unique<Sock>(socket);
      84         [ #  # ]:          0 :         } catch (const std::exception&) {
      85                 :            : #ifdef WIN32
      86                 :            :             closesocket(socket);
      87                 :            : #else
      88         [ #  # ]:          0 :             close(socket);
      89                 :            : #endif
      90   [ #  #  #  # ]:          0 :         }
      91                 :          0 :     }
      92                 :            : 
      93                 :          0 :     return sock;
      94                 :          0 : }
      95                 :            : 
      96                 :          0 : int Sock::GetSockOpt(int level, int opt_name, void* opt_val, socklen_t* opt_len) const
      97                 :            : {
      98                 :          0 :     return getsockopt(m_socket, level, opt_name, static_cast<char*>(opt_val), opt_len);
      99                 :            : }
     100                 :            : 
     101                 :          0 : int Sock::SetSockOpt(int level, int opt_name, const void* opt_val, socklen_t opt_len) const
     102                 :            : {
     103                 :          0 :     return setsockopt(m_socket, level, opt_name, static_cast<const char*>(opt_val), opt_len);
     104                 :            : }
     105                 :            : 
     106                 :          0 : int Sock::GetSockName(sockaddr* name, socklen_t* name_len) const
     107                 :            : {
     108                 :          0 :     return getsockname(m_socket, name, name_len);
     109                 :            : }
     110                 :            : 
     111                 :          0 : bool Sock::SetNonBlocking() const
     112                 :            : {
     113                 :            : #ifdef WIN32
     114                 :            :     u_long on{1};
     115                 :            :     if (ioctlsocket(m_socket, FIONBIO, &on) == SOCKET_ERROR) {
     116                 :            :         return false;
     117                 :            :     }
     118                 :            : #else
     119                 :          0 :     const int flags{fcntl(m_socket, F_GETFL, 0)};
     120         [ #  # ]:          0 :     if (flags == SOCKET_ERROR) {
     121                 :          0 :         return false;
     122                 :            :     }
     123         [ #  # ]:          0 :     if (fcntl(m_socket, F_SETFL, flags | O_NONBLOCK) == SOCKET_ERROR) {
     124                 :          0 :         return false;
     125                 :            :     }
     126                 :            : #endif
     127                 :          0 :     return true;
     128                 :          0 : }
     129                 :            : 
     130                 :          0 : bool Sock::IsSelectable() const
     131                 :            : {
     132                 :            : #if defined(USE_POLL) || defined(WIN32)
     133                 :          0 :     return true;
     134                 :            : #else
     135                 :            :     return m_socket < FD_SETSIZE;
     136                 :            : #endif
     137                 :            : }
     138                 :            : 
     139                 :          0 : bool Sock::Wait(std::chrono::milliseconds timeout, Event requested, Event* occurred) const
     140                 :            : {
     141                 :            :     // We need a `shared_ptr` owning `this` for `WaitMany()`, but don't want
     142                 :            :     // `this` to be destroyed when the `shared_ptr` goes out of scope at the
     143                 :            :     // end of this function. Create it with a custom noop deleter.
     144                 :          0 :     std::shared_ptr<const Sock> shared{this, [](const Sock*) {}};
     145                 :            : 
     146   [ #  #  #  #  :          0 :     EventsPerSock events_per_sock{std::make_pair(shared, Events{requested})};
          #  #  #  #  #  
                      # ]
     147                 :            : 
     148   [ #  #  #  # ]:          0 :     if (!WaitMany(timeout, events_per_sock)) {
     149                 :          0 :         return false;
     150                 :            :     }
     151                 :            : 
     152         [ #  # ]:          0 :     if (occurred != nullptr) {
     153                 :          0 :         *occurred = events_per_sock.begin()->second.occurred;
     154                 :          0 :     }
     155                 :            : 
     156                 :          0 :     return true;
     157                 :          0 : }
     158                 :            : 
     159                 :          0 : bool Sock::WaitMany(std::chrono::milliseconds timeout, EventsPerSock& events_per_sock) const
     160                 :            : {
     161                 :            : #ifdef USE_POLL
     162                 :          0 :     std::vector<pollfd> pfds;
     163         [ #  # ]:          0 :     for (const auto& [sock, events] : events_per_sock) {
     164         [ #  # ]:          0 :         pfds.emplace_back();
     165                 :          0 :         auto& pfd = pfds.back();
     166                 :          0 :         pfd.fd = sock->m_socket;
     167         [ #  # ]:          0 :         if (events.requested & RECV) {
     168                 :          0 :             pfd.events |= POLLIN;
     169                 :          0 :         }
     170         [ #  # ]:          0 :         if (events.requested & SEND) {
     171                 :          0 :             pfd.events |= POLLOUT;
     172                 :          0 :         }
     173                 :            :     }
     174                 :            : 
     175   [ #  #  #  #  :          0 :     if (poll(pfds.data(), pfds.size(), count_milliseconds(timeout)) == SOCKET_ERROR) {
                   #  # ]
     176                 :          0 :         return false;
     177                 :            :     }
     178                 :            : 
     179         [ #  # ]:          0 :     assert(pfds.size() == events_per_sock.size());
     180                 :          0 :     size_t i{0};
     181         [ #  # ]:          0 :     for (auto& [sock, events] : events_per_sock) {
     182         [ #  # ]:          0 :         assert(sock->m_socket == static_cast<SOCKET>(pfds[i].fd));
     183                 :          0 :         events.occurred = 0;
     184         [ #  # ]:          0 :         if (pfds[i].revents & POLLIN) {
     185                 :          0 :             events.occurred |= RECV;
     186                 :          0 :         }
     187         [ #  # ]:          0 :         if (pfds[i].revents & POLLOUT) {
     188                 :          0 :             events.occurred |= SEND;
     189                 :          0 :         }
     190         [ #  # ]:          0 :         if (pfds[i].revents & (POLLERR | POLLHUP)) {
     191                 :          0 :             events.occurred |= ERR;
     192                 :          0 :         }
     193                 :          0 :         ++i;
     194                 :            :     }
     195                 :            : 
     196                 :          0 :     return true;
     197                 :            : #else
     198                 :            :     fd_set recv;
     199                 :            :     fd_set send;
     200                 :            :     fd_set err;
     201                 :            :     FD_ZERO(&recv);
     202                 :            :     FD_ZERO(&send);
     203                 :            :     FD_ZERO(&err);
     204                 :            :     SOCKET socket_max{0};
     205                 :            : 
     206                 :            :     for (const auto& [sock, events] : events_per_sock) {
     207                 :            :         if (!sock->IsSelectable()) {
     208                 :            :             return false;
     209                 :            :         }
     210                 :            :         const auto& s = sock->m_socket;
     211                 :            :         if (events.requested & RECV) {
     212                 :            :             FD_SET(s, &recv);
     213                 :            :         }
     214                 :            :         if (events.requested & SEND) {
     215                 :            :             FD_SET(s, &send);
     216                 :            :         }
     217                 :            :         FD_SET(s, &err);
     218                 :            :         socket_max = std::max(socket_max, s);
     219                 :            :     }
     220                 :            : 
     221                 :            :     timeval tv = MillisToTimeval(timeout);
     222                 :            : 
     223                 :            :     if (select(socket_max + 1, &recv, &send, &err, &tv) == SOCKET_ERROR) {
     224                 :            :         return false;
     225                 :            :     }
     226                 :            : 
     227                 :            :     for (auto& [sock, events] : events_per_sock) {
     228                 :            :         const auto& s = sock->m_socket;
     229                 :            :         events.occurred = 0;
     230                 :            :         if (FD_ISSET(s, &recv)) {
     231                 :            :             events.occurred |= RECV;
     232                 :            :         }
     233                 :            :         if (FD_ISSET(s, &send)) {
     234                 :            :             events.occurred |= SEND;
     235                 :            :         }
     236                 :            :         if (FD_ISSET(s, &err)) {
     237                 :            :             events.occurred |= ERR;
     238                 :            :         }
     239                 :            :     }
     240                 :            : 
     241                 :            :     return true;
     242                 :            : #endif /* USE_POLL */
     243                 :          0 : }
     244                 :            : 
     245                 :          0 : void Sock::SendComplete(const std::string& data,
     246                 :            :                         std::chrono::milliseconds timeout,
     247                 :            :                         CThreadInterrupt& interrupt) const
     248                 :            : {
     249                 :          0 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     250                 :          0 :     size_t sent{0};
     251                 :            : 
     252                 :          0 :     for (;;) {
     253                 :          0 :         const ssize_t ret{Send(data.data() + sent, data.size() - sent, MSG_NOSIGNAL)};
     254                 :            : 
     255         [ #  # ]:          0 :         if (ret > 0) {
     256                 :          0 :             sent += static_cast<size_t>(ret);
     257         [ #  # ]:          0 :             if (sent == data.size()) {
     258                 :          0 :                 break;
     259                 :            :             }
     260                 :          0 :         } else {
     261                 :          0 :             const int err{WSAGetLastError()};
     262         [ #  # ]:          0 :             if (IOErrorIsPermanent(err)) {
     263   [ #  #  #  #  :          0 :                 throw std::runtime_error(strprintf("send(): %s", NetworkErrorString(err)));
             #  #  #  # ]
     264                 :            :             }
     265                 :            :         }
     266                 :            : 
     267                 :          0 :         const auto now = GetTime<std::chrono::milliseconds>();
     268                 :            : 
     269         [ #  # ]:          0 :         if (now >= deadline) {
     270   [ #  #  #  #  :          0 :             throw std::runtime_error(strprintf(
             #  #  #  # ]
     271                 :          0 :                 "Send timeout (sent only %u of %u bytes before that)", sent, data.size()));
     272                 :            :         }
     273                 :            : 
     274         [ #  # ]:          0 :         if (interrupt) {
     275   [ #  #  #  #  :          0 :             throw std::runtime_error(strprintf(
             #  #  #  # ]
     276                 :          0 :                 "Send interrupted (sent only %u of %u bytes before that)", sent, data.size()));
     277                 :            :         }
     278                 :            : 
     279                 :            :         // Wait for a short while (or the socket to become ready for sending) before retrying
     280                 :            :         // if nothing was sent.
     281                 :          0 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
     282                 :          0 :         (void)Wait(wait_time, SEND);
     283                 :            :     }
     284                 :          0 : }
     285                 :            : 
     286                 :          0 : std::string Sock::RecvUntilTerminator(uint8_t terminator,
     287                 :            :                                       std::chrono::milliseconds timeout,
     288                 :            :                                       CThreadInterrupt& interrupt,
     289                 :            :                                       size_t max_data) const
     290                 :            : {
     291                 :          0 :     const auto deadline = GetTime<std::chrono::milliseconds>() + timeout;
     292                 :          0 :     std::string data;
     293                 :          0 :     bool terminator_found{false};
     294                 :            : 
     295                 :            :     // We must not consume any bytes past the terminator from the socket.
     296                 :            :     // One option is to read one byte at a time and check if we have read a terminator.
     297                 :            :     // However that is very slow. Instead, we peek at what is in the socket and only read
     298                 :            :     // as many bytes as possible without crossing the terminator.
     299                 :            :     // Reading 64 MiB of random data with 262526 terminator chars takes 37 seconds to read
     300                 :            :     // one byte at a time VS 0.71 seconds with the "peek" solution below. Reading one byte
     301                 :            :     // at a time is about 50 times slower.
     302                 :            : 
     303                 :          0 :     for (;;) {
     304         [ #  # ]:          0 :         if (data.size() >= max_data) {
     305   [ #  #  #  # ]:          0 :             throw std::runtime_error(
     306         [ #  # ]:          0 :                 strprintf("Received too many bytes without a terminator (%u)", data.size()));
     307                 :            :         }
     308                 :            : 
     309                 :            :         char buf[512];
     310                 :            : 
     311   [ #  #  #  # ]:          0 :         const ssize_t peek_ret{Recv(buf, std::min(sizeof(buf), max_data - data.size()), MSG_PEEK)};
     312                 :            : 
     313      [ #  #  # ]:          0 :         switch (peek_ret) {
     314                 :            :         case -1: {
     315                 :          0 :             const int err{WSAGetLastError()};
     316   [ #  #  #  # ]:          0 :             if (IOErrorIsPermanent(err)) {
     317   [ #  #  #  #  :          0 :                 throw std::runtime_error(strprintf("recv(): %s", NetworkErrorString(err)));
             #  #  #  # ]
     318                 :            :             }
     319                 :          0 :             break;
     320                 :            :         }
     321                 :            :         case 0:
     322         [ #  # ]:          0 :             throw std::runtime_error("Connection unexpectedly closed by peer");
     323                 :            :         default:
     324                 :          0 :             auto end = buf + peek_ret;
     325         [ #  # ]:          0 :             auto terminator_pos = std::find(buf, end, terminator);
     326                 :          0 :             terminator_found = terminator_pos != end;
     327                 :            : 
     328         [ #  # ]:          0 :             const size_t try_len{terminator_found ? terminator_pos - buf + 1 :
     329                 :          0 :                                                     static_cast<size_t>(peek_ret)};
     330                 :            : 
     331         [ #  # ]:          0 :             const ssize_t read_ret{Recv(buf, try_len, 0)};
     332                 :            : 
     333         [ #  # ]:          0 :             if (read_ret < 0 || static_cast<size_t>(read_ret) != try_len) {
     334   [ #  #  #  # ]:          0 :                 throw std::runtime_error(
     335         [ #  # ]:          0 :                     strprintf("recv() returned %u bytes on attempt to read %u bytes but previous "
     336                 :            :                               "peek claimed %u bytes are available",
     337                 :            :                               read_ret, try_len, peek_ret));
     338                 :            :             }
     339                 :            : 
     340                 :            :             // Don't include the terminator in the output.
     341         [ #  # ]:          0 :             const size_t append_len{terminator_found ? try_len - 1 : try_len};
     342                 :            : 
     343         [ #  # ]:          0 :             data.append(buf, buf + append_len);
     344                 :            : 
     345         [ #  # ]:          0 :             if (terminator_found) {
     346                 :          0 :                 return data;
     347                 :            :             }
     348                 :          0 :         }
     349                 :            : 
     350         [ #  # ]:          0 :         const auto now = GetTime<std::chrono::milliseconds>();
     351                 :            : 
     352   [ #  #  #  # ]:          0 :         if (now >= deadline) {
     353   [ #  #  #  #  :          0 :             throw std::runtime_error(strprintf(
             #  #  #  # ]
     354                 :          0 :                 "Receive timeout (received %u bytes without terminator before that)", data.size()));
     355                 :            :         }
     356                 :            : 
     357   [ #  #  #  # ]:          0 :         if (interrupt) {
     358   [ #  #  #  #  :          0 :             throw std::runtime_error(strprintf(
             #  #  #  # ]
     359                 :            :                 "Receive interrupted (received %u bytes without terminator before that)",
     360                 :          0 :                 data.size()));
     361                 :            :         }
     362                 :            : 
     363                 :            :         // Wait for a short while (or the socket to become ready for reading) before retrying.
     364   [ #  #  #  #  :          0 :         const auto wait_time = std::min(deadline - now, std::chrono::milliseconds{MAX_WAIT_FOR_IO});
                   #  # ]
     365         [ #  # ]:          0 :         (void)Wait(wait_time, RECV);
     366                 :            :     }
     367         [ #  # ]:          0 : }
     368                 :            : 
     369                 :          0 : bool Sock::IsConnected(std::string& errmsg) const
     370                 :            : {
     371         [ #  # ]:          0 :     if (m_socket == INVALID_SOCKET) {
     372                 :          0 :         errmsg = "not connected";
     373                 :          0 :         return false;
     374                 :            :     }
     375                 :            : 
     376                 :            :     char c;
     377      [ #  #  # ]:          0 :     switch (Recv(&c, sizeof(c), MSG_PEEK)) {
     378                 :            :     case -1: {
     379                 :          0 :         const int err = WSAGetLastError();
     380         [ #  # ]:          0 :         if (IOErrorIsPermanent(err)) {
     381                 :          0 :             errmsg = NetworkErrorString(err);
     382                 :          0 :             return false;
     383                 :            :         }
     384                 :          0 :         return true;
     385                 :            :     }
     386                 :            :     case 0:
     387                 :          0 :         errmsg = "closed";
     388                 :          0 :         return false;
     389                 :            :     default:
     390                 :          0 :         return true;
     391                 :            :     }
     392                 :          0 : }
     393                 :            : 
     394                 :      22407 : void Sock::Close()
     395                 :            : {
     396         [ -  + ]:      22407 :     if (m_socket == INVALID_SOCKET) {
     397                 :      22407 :         return;
     398                 :            :     }
     399                 :            : #ifdef WIN32
     400                 :            :     int ret = closesocket(m_socket);
     401                 :            : #else
     402                 :          0 :     int ret = close(m_socket);
     403                 :            : #endif
     404         [ #  # ]:          0 :     if (ret) {
     405   [ #  #  #  #  :          0 :         LogPrintf("Error closing socket %d: %s\n", m_socket, NetworkErrorString(WSAGetLastError()));
             #  #  #  # ]
     406                 :          0 :     }
     407                 :          0 :     m_socket = INVALID_SOCKET;
     408                 :      22407 : }
     409                 :            : 
     410                 :          0 : bool Sock::operator==(SOCKET s) const
     411                 :            : {
     412                 :          0 :     return m_socket == s;
     413                 :            : };
     414                 :            : 
     415                 :       6266 : std::string NetworkErrorString(int err)
     416                 :            : {
     417                 :            : #if defined(WIN32)
     418                 :            :     return Win32ErrorString(err);
     419                 :            : #else
     420                 :            :     // On BSD sockets implementations, NetworkErrorString is the same as SysErrorString.
     421                 :       6266 :     return SysErrorString(err);
     422                 :            : #endif
     423                 :            : }

Generated by: LCOV version 1.14