Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions src/net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,16 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
if (!addr_bind.IsValid()) {
addr_bind = GetBindAddress(sock->Get());
}
CNode* pnode = new CNode(id, nLocalServices, sock->Release(), addrConnect, CalculateKeyedNetGroup(addrConnect), nonce, addr_bind, pszDest ? pszDest : "", conn_type, /* inbound_onion */ false);
CNode* pnode = new CNode(id,
nLocalServices,
std::move(sock),
addrConnect,
CalculateKeyedNetGroup(addrConnect),
nonce,
addr_bind,
pszDest ? pszDest : "",
conn_type,
/*inbound_onion=*/false);
pnode->AddRef();

// We're making a new connection, harvest entropy from the time (and our peer count)
Expand All @@ -517,11 +526,10 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
void CNode::CloseSocketDisconnect()
{
fDisconnect = true;
LOCK(cs_hSocket);
if (hSocket != INVALID_SOCKET)
{
LOCK(m_sock_mutex);
if (m_sock) {
LogPrint(BCLog::NET, "disconnecting peer=%d\n", id);
CloseSocket(hSocket);
m_sock.reset();
}
}

Expand Down Expand Up @@ -799,10 +807,11 @@ size_t CConnman::SocketSendData(CNode& node) const
assert(data.size() > node.nSendOffset);
int nBytes = 0;
{
LOCK(node.cs_hSocket);
if (node.hSocket == INVALID_SOCKET)
LOCK(node.m_sock_mutex);
if (!node.m_sock) {
break;
nBytes = send(node.hSocket, reinterpret_cast<const char*>(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT);
}
nBytes = node.m_sock->Send(reinterpret_cast<const char*>(data.data()) + node.nSendOffset, data.size() - node.nSendOffset, MSG_NOSIGNAL | MSG_DONTWAIT);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a silly question: LOCK(node.m_sock_mutex); applies even to node.m_sock->Send(...) operation but then here one can return contained sockets out of the function.

I understand that m_sock_mutex guards only m_sock member read & write operations but not necessarily the contained socket. Is that correct or not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is a good question.

Before this PR the mutex protects the value of the socket variable - that is, to make sure it does not change from e.g. 7 (it is a file descriptor) to INVALID_SOCKET in the middle of constructs like:

if (socket != INVALID_SOCKET) {
    ... do something with the socket, assuming it is valid ...
}

After this PR, it is the same with the exception that the variable is not a bare integer anymore, but a shared_ptr and an empty such pointer is used to designate what was designated before by INVALID_SOCKET.

}
if (nBytes > 0) {
node.m_last_send = GetTime<std::chrono::seconds>();
Expand Down Expand Up @@ -1197,7 +1206,16 @@ void CConnman::CreateNodeFromAcceptedSocket(std::unique_ptr<Sock>&& sock,
}

const bool inbound_onion = std::find(m_onion_binds.begin(), m_onion_binds.end(), addr_bind) != m_onion_binds.end();
CNode* pnode = new CNode(id, nodeServices, sock->Release(), addr, CalculateKeyedNetGroup(addr), nonce, addr_bind, "", ConnectionType::INBOUND, inbound_onion);
CNode* pnode = new CNode(id,
nodeServices,
std::move(sock),
addr,
CalculateKeyedNetGroup(addr),
nonce,
addr_bind,
/*addrNameIn=*/"",
ConnectionType::INBOUND,
inbound_onion);
pnode->AddRef();
pnode->m_permissionFlags = permissionFlags;
pnode->m_prefer_evict = discouraged;
Expand Down Expand Up @@ -1381,17 +1399,18 @@ bool CConnman::GenerateSelectSet(const std::vector<CNode*>& nodes,
select_send = !pnode->vSendMsg.empty();
}

LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET)
LOCK(pnode->m_sock_mutex);
if (!pnode->m_sock) {
continue;
}

error_set.insert(pnode->hSocket);
error_set.insert(pnode->m_sock->Get());
if (select_send) {
send_set.insert(pnode->hSocket);
send_set.insert(pnode->m_sock->Get());
continue;
}
if (select_recv) {
recv_set.insert(pnode->hSocket);
recv_set.insert(pnode->m_sock->Get());
}
}

Expand Down Expand Up @@ -1561,23 +1580,25 @@ void CConnman::SocketHandlerConnected(const std::vector<CNode*>& nodes,
bool sendSet = false;
bool errorSet = false;
{
LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET)
LOCK(pnode->m_sock_mutex);
if (!pnode->m_sock) {
continue;
recvSet = recv_set.count(pnode->hSocket) > 0;
sendSet = send_set.count(pnode->hSocket) > 0;
errorSet = error_set.count(pnode->hSocket) > 0;
}
recvSet = recv_set.count(pnode->m_sock->Get()) > 0;
sendSet = send_set.count(pnode->m_sock->Get()) > 0;
errorSet = error_set.count(pnode->m_sock->Get()) > 0;
}
if (recvSet || errorSet)
{
// typical socket buffer is 8K-64K
uint8_t pchBuf[0x10000];
int nBytes = 0;
{
LOCK(pnode->cs_hSocket);
if (pnode->hSocket == INVALID_SOCKET)
LOCK(pnode->m_sock_mutex);
if (!pnode->m_sock) {
continue;
nBytes = recv(pnode->hSocket, (char*)pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
}
nBytes = pnode->m_sock->Recv(pchBuf, sizeof(pchBuf), MSG_DONTWAIT);
}
if (nBytes > 0)
{
Expand Down Expand Up @@ -2962,8 +2983,9 @@ ServiceFlags CConnman::GetLocalServices() const

unsigned int CConnman::GetReceiveFloodSize() const { return nReceiveFloodSize; }

CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion)
: m_connected{GetTime<std::chrono::seconds>()},
CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, std::shared_ptr<Sock> sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion)
: m_sock{sock},
m_connected{GetTime<std::chrono::seconds>()},
addr(addrIn),
addrBind(addrBindIn),
m_addr_name{addrNameIn.empty() ? addr.ToStringIPPort() : addrNameIn},
Expand All @@ -2975,7 +2997,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const
nLocalServices(nLocalServicesIn)
{
if (inbound_onion) assert(conn_type_in == ConnectionType::INBOUND);
hSocket = hSocketIn;
if (conn_type_in != ConnectionType::BLOCK_RELAY) {
m_tx_relay = std::make_unique<TxRelay>();
}
Expand All @@ -2994,11 +3015,6 @@ CNode::CNode(NodeId idIn, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const
m_serializer = std::make_unique<V1TransportSerializer>(V1TransportSerializer());
}

CNode::~CNode()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies if this is a silly question but just wanted to clarify - removing this destructor, it's implicit that the shared_ptr for Sock will be reset and the Sock's destructor will responsible for closing the socket?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://en.cppreference.com/w/cpp/language/destructor "Implicitly-declared destructor".

The default destructor for CNode will call all the destructors of all of it's members. The destructor for shared_ptr will reduce the reference counter, and then if that reference counter is zero, the destructor for Sock will be called. If the RC is not zero, then the Sock will continue to exist on the heap (presumably being used by another thread), until the owner of that shared_ptr calls the destructor.

{
CloseSocket(hSocket);
}

bool CConnman::NodeFullyConnected(const CNode* pnode)
{
return pnode && pnode->fSuccessfullyConnected && !pnode->fDisconnect;
Expand Down
17 changes: 13 additions & 4 deletions src/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,25 @@ class CNode

NetPermissionFlags m_permissionFlags{NetPermissionFlags::None};
std::atomic<ServiceFlags> nServices{NODE_NONE};
SOCKET hSocket GUARDED_BY(cs_hSocket);

/**
* Socket used for communication with the node.
* May not own a Sock object (after `CloseSocketDisconnect()` or during tests).
* `shared_ptr` (instead of `unique_ptr`) is used to avoid premature close of
* the underlying file descriptor by one thread while another thread is
* poll(2)-ing it for activity.
* @see https://github.com/bitcoin/bitcoin/issues/21744 for details.
*/
std::shared_ptr<Sock> m_sock GUARDED_BY(m_sock_mutex);

/** Total size of all vSendMsg entries */
size_t nSendSize GUARDED_BY(cs_vSend){0};
/** Offset inside the first vSendMsg already sent */
size_t nSendOffset GUARDED_BY(cs_vSend){0};
uint64_t nSendBytes GUARDED_BY(cs_vSend){0};
std::deque<std::vector<unsigned char>> vSendMsg GUARDED_BY(cs_vSend);
Mutex cs_vSend;
Mutex cs_hSocket;
Mutex m_sock_mutex;
Mutex cs_vRecv;

RecursiveMutex cs_vProcessMsg;
Expand Down Expand Up @@ -578,8 +588,7 @@ class CNode
* criterium in CConnman::AttemptToEvictConnection. */
std::atomic<std::chrono::microseconds> m_min_ping_time{std::chrono::microseconds::max()};

CNode(NodeId id, ServiceFlags nLocalServicesIn, SOCKET hSocketIn, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion);
~CNode();
CNode(NodeId id, ServiceFlags nLocalServicesIn, std::shared_ptr<Sock> sock, const CAddress& addrIn, uint64_t nKeyedNetGroupIn, uint64_t nLocalHostNonceIn, const CAddress& addrBindIn, const std::string& addrNameIn, ConnectionType conn_type_in, bool inbound_onion);
CNode(const CNode&) = delete;
CNode& operator=(const CNode&) = delete;

Expand Down
72 changes: 60 additions & 12 deletions src/test/denialofservice_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)

// Mock an outbound peer
CAddress addr1(ip(0xa0b0c001), NODE_NONE);
CNode dummyNode1(id++, ServiceFlags(NODE_NETWORK | NODE_WITNESS), INVALID_SOCKET, addr1, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"", ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false);
CNode dummyNode1{id++,
ServiceFlags(NODE_NETWORK | NODE_WITNESS),
/*sock=*/nullptr,
addr1,
/*nKeyedNetGroupIn=*/0,
/*nLocalHostNonceIn=*/0,
CAddress(),
/*addrNameIn=*/"",
ConnectionType::OUTBOUND_FULL_RELAY,
/*inbound_onion=*/false};
dummyNode1.SetCommonVersion(PROTOCOL_VERSION);

peerLogic->InitializeNode(&dummyNode1);
Expand Down Expand Up @@ -108,7 +117,16 @@ BOOST_AUTO_TEST_CASE(outbound_slow_chain_eviction)
static void AddRandomOutboundPeer(std::vector<CNode*>& vNodes, PeerManager& peerLogic, ConnmanTestMsg& connman, ConnectionType connType)
{
CAddress addr(ip(g_insecure_rand_ctx.randbits(32)), NODE_NONE);
vNodes.emplace_back(new CNode(id++, ServiceFlags(NODE_NETWORK | NODE_WITNESS), INVALID_SOCKET, addr, /*nKeyedNetGroupIn=*/0, /*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"", connType, /*inbound_onion=*/false));
vNodes.emplace_back(new CNode{id++,
ServiceFlags(NODE_NETWORK | NODE_WITNESS),
/*sock=*/nullptr,
addr,
/*nKeyedNetGroupIn=*/0,
/*nLocalHostNonceIn=*/0,
CAddress(),
/*addrNameIn=*/"",
connType,
/*inbound_onion=*/false});
CNode &node = *vNodes.back();
node.SetCommonVersion(PROTOCOL_VERSION);

Expand Down Expand Up @@ -279,9 +297,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)
std::array<CNode*, 3> nodes;

banman->ClearBanned();
nodes[0] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[0], /*nKeyedNetGroupIn=*/0,
/*nLocalHostNonceIn=*/0, CAddress(), /*addrNameIn=*/"",
ConnectionType::INBOUND, /*inbound_onion=*/false};
nodes[0] = new CNode{id++,
NODE_NETWORK,
/*sock=*/nullptr,
addr[0],
/*nKeyedNetGroupIn=*/0,
/*nLocalHostNonceIn=*/0,
CAddress(),
/*addrNameIn=*/"",
ConnectionType::INBOUND,
/*inbound_onion=*/false};
nodes[0]->SetCommonVersion(PROTOCOL_VERSION);
peerLogic->InitializeNode(nodes[0]);
nodes[0]->fSuccessfullyConnected = true;
Expand All @@ -295,9 +320,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)
BOOST_CHECK(nodes[0]->fDisconnect);
BOOST_CHECK(!banman->IsDiscouraged(other_addr)); // Different address, not discouraged

nodes[1] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[1], /*nKeyedNetGroupIn=*/1,
/*nLocalHostNonceIn=*/1, CAddress(), /*addrNameIn=*/"",
ConnectionType::INBOUND, /*inbound_onion=*/false};
nodes[1] = new CNode{id++,
NODE_NETWORK,
/*sock=*/nullptr,
addr[1],
/*nKeyedNetGroupIn=*/1,
/*nLocalHostNonceIn=*/1,
CAddress(),
/*addrNameIn=*/"",
ConnectionType::INBOUND,
/*inbound_onion=*/false};
nodes[1]->SetCommonVersion(PROTOCOL_VERSION);
peerLogic->InitializeNode(nodes[1]);
nodes[1]->fSuccessfullyConnected = true;
Expand Down Expand Up @@ -326,9 +358,16 @@ BOOST_AUTO_TEST_CASE(peer_discouragement)

// Make sure non-IP peers are discouraged and disconnected properly.

nodes[2] = new CNode{id++, NODE_NETWORK, INVALID_SOCKET, addr[2], /*nKeyedNetGroupIn=*/1,
/*nLocalHostNonceIn=*/1, CAddress(), /*addrNameIn=*/"",
ConnectionType::OUTBOUND_FULL_RELAY, /*inbound_onion=*/false};
nodes[2] = new CNode{id++,
NODE_NETWORK,
/*sock=*/nullptr,
addr[2],
/*nKeyedNetGroupIn=*/1,
/*nLocalHostNonceIn=*/1,
CAddress(),
/*addrNameIn=*/"",
ConnectionType::OUTBOUND_FULL_RELAY,
/*inbound_onion=*/false};
nodes[2]->SetCommonVersion(PROTOCOL_VERSION);
peerLogic->InitializeNode(nodes[2]);
nodes[2]->fSuccessfullyConnected = true;
Expand Down Expand Up @@ -364,7 +403,16 @@ BOOST_AUTO_TEST_CASE(DoS_bantime)
SetMockTime(nStartTime); // Overrides future calls to GetTime()

CAddress addr(ip(0xa0b0c001), NODE_NONE);
CNode dummyNode(id++, NODE_NETWORK, INVALID_SOCKET, addr, /*nKeyedNetGroupIn=*/4, /*nLocalHostNonceIn=*/4, CAddress(), /*addrNameIn=*/"", ConnectionType::INBOUND, /*inbound_onion=*/false);
CNode dummyNode{id++,
NODE_NETWORK,
/*sock=*/nullptr,
addr,
/*nKeyedNetGroupIn=*/4,
/*nLocalHostNonceIn=*/4,
CAddress(),
/*addrNameIn=*/"",
ConnectionType::INBOUND,
/*inbound_onion=*/false};
dummyNode.SetCommonVersion(PROTOCOL_VERSION);
peerLogic->InitializeNode(&dummyNode);
dummyNode.fSuccessfullyConnected = true;
Expand Down
Loading