diff --git a/src/js_udp_wrap.cc b/src/js_udp_wrap.cc index acd956bf00..c4f5356d78 100644 --- a/src/js_udp_wrap.cc +++ b/src/js_udp_wrap.cc @@ -182,7 +182,7 @@ void JSUDPWrap::EmitReceived(const FunctionCallbackInfo& args) { data += avail; len -= avail; wrap->listener()->OnRecv( - avail, buf, reinterpret_cast(&addr), flags); + nullptr, avail, buf, reinterpret_cast(&addr), flags); } } diff --git a/src/node_quic_default_application.cc b/src/node_quic_default_application.cc index 59145fba36..3b862ccf94 100644 --- a/src/node_quic_default_application.cc +++ b/src/node_quic_default_application.cc @@ -207,7 +207,7 @@ bool DefaultApplication::SendStreamData(QuicStream* stream) { Debug(stream, "Sending %" PRIu64 " bytes in serialized packet", nwrite); dest.Realloc(nwrite); - if (!Session()->SendPacket(std::move(dest), &path)) + if (!Session()->SendPacket(std::move(dest), path)) return false; if (IsEmpty(v, c)) { diff --git a/src/node_quic_http3_application.cc b/src/node_quic_http3_application.cc index 940f1b5ba2..944cb83854 100644 --- a/src/node_quic_http3_application.cc +++ b/src/node_quic_http3_application.cc @@ -521,7 +521,7 @@ bool Http3Application::SendPendingData() { Debug(Session(), "Sending %" PRIu64 " bytes in serialized packet", nwrite); dest.Realloc(nwrite); - if (!Session()->SendPacket(std::move(dest), &path)) + if (!Session()->SendPacket(std::move(dest), path)) return false; if (fin) diff --git a/src/node_quic_session.cc b/src/node_quic_session.cc index 5f42a91d2d..f043b594e2 100644 --- a/src/node_quic_session.cc +++ b/src/node_quic_session.cc @@ -1173,7 +1173,8 @@ QuicSession::QuicSession( QuicSessionConfig* config, Local wrap, const ngtcp2_cid* rcid, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -1193,14 +1194,22 @@ QuicSession::QuicSession( options, QUIC_PREFERRED_ADDRESS_ACCEPT, // Not used on server sessions initial_connection_close) { - InitServer(config, addr, dcid, ocid, version, qlog); + InitServer( + config, + local_addr, + remote_addr, + dcid, + ocid, + version, + qlog); } // Client QuicSession Constructor QuicSession::QuicSession( QuicSocket* socket, v8::Local wrap, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, SecureContext* context, Local early_transport_params, Local session_ticket, @@ -1221,7 +1230,13 @@ QuicSession::QuicSession( nullptr, // rcid only used on the server options, select_preferred_address_policy) { - CHECK(InitClient(addr, early_transport_params, session_ticket, dcid, qlog)); + CHECK(InitClient( + local_addr, + remote_addr, + early_transport_params, + session_ticket, + dcid, + qlog)); } // QuicSession is an abstract base class that defines the code used by both @@ -1753,7 +1768,7 @@ void QuicSession::PathValidation( Debug(this, "Path validation succeeded. Updating local and remote addresses"); SetLocalAddress(&path->local); - remote_address_.Update(path->remote.addr, path->remote.addrlen); + UpdateEndpoint(*path); IncrementStat( 1, &session_stats_, &session_stats::path_validation_success_count); @@ -1814,6 +1829,7 @@ bool QuicSession::ReceiveRetry() { bool QuicSession::Receive( ssize_t nread, const uint8_t* data, + const SocketAddress& local_addr, const struct sockaddr* remote_addr, unsigned int flags) { if (IsFlagSet(QUICSESSION_FLAG_DESTROYED)) { @@ -1858,7 +1874,7 @@ bool QuicSession::Receive( // packet to the next so we have to look at the addr on // every packet. remote_address_ = remote_addr; - QuicPath path(Socket()->GetLocalAddress(), &remote_address_); + QuicPath path(local_addr, remote_address_); { // These are within a scope to ensure that the InternalCallbackScope @@ -2191,10 +2207,9 @@ bool QuicSession::SelectPreferredAddress( bool QuicSession::SendPacket( MallocedBuffer buf, - ngtcp2_path_storage* path) { + const ngtcp2_path_storage& path) { sendbuf_.Push(std::move(buf)); - // TODO(@jasnell): Update the local endpoint also? - remote_address_.Update(path->path.remote.addr, path->path.remote.addrlen); + UpdateEndpoint(path.path); return SendPacket("stream data"); } @@ -2242,7 +2257,13 @@ bool QuicSession::SendPacket(const char* diagnostic_label) { Debug(this, "There are %" PRIu64 " bytes in txbuf_ to send", txbuf_.Length()); session_stats_.session_sent_at = uv_hrtime(); ScheduleRetransmit(); + Debug(this, "Sending to %s:%d from %s:%d", + remote_address_.GetAddress().c_str(), + remote_address_.GetPort(), + local_address_.GetAddress().c_str(), + local_address_.GetPort()); int err = Socket()->SendPacket( + local_address_, remote_address_, &txbuf_, BaseObjectPtr(this), @@ -2611,14 +2632,20 @@ bool QuicSession::WritePackets(const char* diagnostic_label) { } data.Realloc(nwrite); - remote_address_.Update(path.path.remote.addr, path.path.remote.addrlen); + UpdateEndpoint(path.path); sendbuf_.Push(std::move(data)); UpdateDataStats(); + if (!SendPacket(diagnostic_label)) return false; } } +void QuicSession::UpdateEndpoint(const ngtcp2_path& path) { + remote_address_.Update(path.remote.addr, path.remote.addrlen); + local_address_.Update(path.local.addr, path.local.addrlen); +} + bool QuicSession::SubmitInformation( int64_t stream_id, v8::Local headers) { @@ -2692,7 +2719,8 @@ BaseObjectPtr QuicSession::CreateServer( QuicSocket* socket, QuicSessionConfig* config, const ngtcp2_cid* rcid, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -2712,7 +2740,8 @@ BaseObjectPtr QuicSession::CreateServer( config, obj, rcid, - addr, + local_addr, + remote_addr, dcid, ocid, version, @@ -2727,7 +2756,8 @@ BaseObjectPtr QuicSession::CreateServer( void QuicSession::InitServer( QuicSessionConfig* config, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -2738,8 +2768,9 @@ void QuicSession::InitServer( ExtendMaxStreamsBidi(DEFAULT_MAX_STREAMS_BIDI); ExtendMaxStreamsUni(DEFAULT_MAX_STREAMS_UNI); - remote_address_ = addr; - max_pktlen_ = GetMaxPktLen(addr); + local_address_ = local_addr; + remote_address_ = remote_addr; + max_pktlen_ = GetMaxPktLen(remote_addr); config->SetOriginalConnectionID(ocid); config->GenerateStatelessResetToken(); @@ -2749,7 +2780,7 @@ void QuicSession::InitServer( EntropySource(scid_.data, NGTCP2_SV_SCIDLEN); scid_.datalen = NGTCP2_SV_SCIDLEN; - QuicPath path(Socket()->GetLocalAddress(), &remote_address_); + QuicPath path(local_addr, remote_address_); // NOLINTNEXTLINE(readability/pointer_notation) if (qlog == QlogMode::kEnabled) config->SetQlog({ *ocid, OnQlogWrite }); @@ -2838,6 +2869,7 @@ BaseObjectPtr QuicSession::CreateClient( MakeDetachedBaseObject( socket, obj, + *socket->GetLocalAddress(), addr, context, early_transport_params, @@ -2854,15 +2886,22 @@ BaseObjectPtr QuicSession::CreateClient( } bool QuicSession::InitClient( - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, Local early_transport_params, Local session_ticket, Local dcid_value, QlogMode qlog) { CHECK_NULL(connection_); - remote_address_ = addr; - max_pktlen_ = GetMaxPktLen(addr); + local_address_ = local_addr; + remote_address_ = remote_addr; + Debug(this, "Initializing connection from %s:%d to %s:%d", + local_address_.GetAddress().c_str(), + local_address_.GetPort(), + remote_address_.GetAddress().c_str(), + remote_address_.GetPort()); + max_pktlen_ = GetMaxPktLen(remote_addr); QuicSessionConfig config(env()); max_crypto_buffer_ = config.GetMaxCryptoBuffer(); @@ -2885,7 +2924,7 @@ bool QuicSession::InitClient( EntropySource(dcid.data, dcid.datalen); } - QuicPath path(Socket()->GetLocalAddress(), &remote_address_); + QuicPath path(local_address_, remote_address_); if (qlog == QlogMode::kEnabled) config.SetQlog({ dcid, OnQlogWrite }); @@ -2902,6 +2941,8 @@ bool QuicSession::InitClient( &alloc_info_, static_cast(this)), 0); + auto n = ngtcp2_conn_get_remote_addr(conn); + connection_.reset(conn); InitializeTLS(this); diff --git a/src/node_quic_session.h b/src/node_quic_session.h index b6a4965e4a..eb56488c58 100644 --- a/src/node_quic_session.h +++ b/src/node_quic_session.h @@ -581,7 +581,8 @@ class QuicSession : public AsyncWrap, QuicSocket* socket, QuicSessionConfig* config, const ngtcp2_cid* rcid, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -638,7 +639,8 @@ class QuicSession : public AsyncWrap, QuicSessionConfig* config, v8::Local wrap, const ngtcp2_cid* rcid, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -651,7 +653,8 @@ class QuicSession : public AsyncWrap, QuicSession( QuicSocket* socket, v8::Local wrap, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, crypto::SecureContext* context, v8::Local early_transport_params, v8::Local session_ticket, @@ -804,6 +807,7 @@ class QuicSession : public AsyncWrap, bool Receive( ssize_t nread, const uint8_t* data, + const SocketAddress& local_addr, const struct sockaddr* remote_addr, unsigned int flags); @@ -824,7 +828,9 @@ class QuicSession : public AsyncWrap, // Causes pending QuicStream data to be serialized and sent bool SendStreamData(QuicStream* stream); - bool SendPacket(MallocedBuffer buf, ngtcp2_path_storage* path); + bool SendPacket( + MallocedBuffer buf, + const ngtcp2_path_storage& path); inline uint64_t GetMaxDataLeft(); @@ -986,7 +992,8 @@ class QuicSession : public AsyncWrap, // Initialize the QuicSession as a server void InitServer( QuicSessionConfig* config, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, const ngtcp2_cid* dcid, const ngtcp2_cid* ocid, uint32_t version, @@ -994,7 +1001,8 @@ class QuicSession : public AsyncWrap, // Initialize the QuicSession as a client bool InitClient( - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, v8::Local early_transport_params, v8::Local session_ticket, v8::Local dcid, @@ -1040,6 +1048,7 @@ class QuicSession : public AsyncWrap, bool WritePackets(const char* diagnostic_label = nullptr); void UpdateRecoveryStats(); void UpdateDataStats(); + void UpdateEndpoint(const ngtcp2_path& path); void VersionNegotiation( const ngtcp2_pkt_hd* hd, @@ -1303,6 +1312,7 @@ class QuicSession : public AsyncWrap, uint64_t{NGTCP2_NO_ERROR} }; ConnectionPointer connection_; + SocketAddress local_address_; SocketAddress remote_address_; uint32_t flags_ = 0; uint64_t initial_connection_close_ = NGTCP2_NO_ERROR; diff --git a/src/node_quic_socket.cc b/src/node_quic_socket.cc index 4b38ae7820..1c453fa559 100644 --- a/src/node_quic_socket.cc +++ b/src/node_quic_socket.cc @@ -317,6 +317,7 @@ uv_buf_t QuicSocket::OnAlloc(size_t suggested_size) { } void QuicSocket::OnRecv( + uv_udp_t* handle, ssize_t nread, const uv_buf_t& buf_, const struct sockaddr* addr, @@ -332,7 +333,11 @@ void QuicSocket::OnRecv( return; } - Receive(nread, std::move(buf), addr, flags); + SocketAddress local_address; + if (handle != nullptr) + SocketAddress::FromSockName(handle, &local_address); + + Receive(nread, std::move(buf), local_address, addr, flags); } namespace { @@ -349,7 +354,8 @@ bool IsShortHeader( void QuicSocket::Receive( ssize_t nread, AllocatedBuffer buf, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, unsigned int flags) { Debug(this, "Receiving %d bytes from the UDP socket.", nread); @@ -421,7 +427,8 @@ void QuicSocket::Receive( scid, nread, data, - addr, + local_addr, + remote_addr, flags); // There are many reasons why a server QuicSession could not be @@ -443,7 +450,7 @@ void QuicSocket::Receive( // Attempt to send a stateless reset. If it fails, we just ignore // TODO(@jasnell): Need to verify that stateless reset is occurring // correctly. Also need to determine how to test. - if (!SendStatelessReset(dcid, addr)) { + if (!SendStatelessReset(dcid, remote_addr)) { IncrementSocketStat( 1, &socket_stats_, &socket_stats::packets_ignored); @@ -469,7 +476,7 @@ void QuicSocket::Receive( // If the packet could not successfully processed for any reason (possibly // due to being malformed or malicious in some way) we ignore it completely. - if (!session->Receive(nread, data, addr, flags)) { + if (!session->Receive(nread, data, local_addr, remote_addr, flags)) { IncrementSocketStat(1, &socket_stats_, &socket_stats::packets_ignored); return; } @@ -682,7 +689,8 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( const QuicCID& scid, ssize_t nread, const uint8_t* data, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, unsigned int flags) { HandleScope handle_scope(env()->isolate()); Context::Scope context_scope(env()->context()); @@ -700,11 +708,11 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( // acceptable initial packet with the right QUIC version. switch (QuicSession::Accept(&hd, version, data, nread)) { case QuicSession::InitialPacketResult::PACKET_VERSION: - SendVersionNegotiation(version, dcid, scid, addr); + SendVersionNegotiation(version, dcid, scid, remote_addr); return {}; case QuicSession::InitialPacketResult::PACKET_RETRY: Debug(this, "0RTT Packet. Sending retry."); - SendRetry(version, dcid, scid, addr); + SendRetry(version, dcid, scid, remote_addr); return {}; case QuicSession::InitialPacketResult::PACKET_IGNORE: return {}; @@ -722,7 +730,8 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( // Check to see if the number of connections for this peer has been exceeded. // If the count has been exceeded, shutdown the connection immediately // after the initial keys are installed. - if (GetCurrentSocketAddressCounter(addr) >= max_connections_per_host_) { + if (GetCurrentSocketAddressCounter(remote_addr) >= + max_connections_per_host_) { Debug(this, "Connection count for address exceeded"); initial_connection_close = NGTCP2_SERVER_BUSY; } @@ -743,21 +752,21 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( // will check to see if the given address is in the validated_addrs_ // LRU cache. If it is, we'll skip the validation step entirely. // The VALIDATE_ADDRESS_LRU option is disable by default. - if (!IsValidatedAddress(addr)) { + if (!IsValidatedAddress(remote_addr)) { Debug(this, "Performing explicit address validation."); if (InvalidRetryToken( env(), &ocid, &hd, - addr, + remote_addr, token_secret_, retry_token_expiration_)) { Debug(this, "A valid retry token was not found. Sending retry."); - SendRetry(version, dcid, scid, addr); + SendRetry(version, dcid, scid, remote_addr); return {}; } Debug(this, "A valid retry token was found. Continuing."); - SetValidatedAddress(addr); + SetValidatedAddress(remote_addr); ocid_ptr = &ocid; } else { Debug(this, "Skipping validation for recently validated address."); @@ -769,7 +778,8 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( this, &server_session_config_, dcid.cid(), - addr, + local_addr, + remote_addr, scid.cid(), ocid_ptr, version, @@ -837,7 +847,8 @@ ReqWrap* QuicSocket::CreateSendWrap(size_t msg_size) { } int QuicSocket::SendPacket( - const SocketAddress& dest, + const SocketAddress& local_addr, + const SocketAddress& remote_addr, QuicBuffer* buffer, BaseObjectPtr session, const char* diagnostic_label) { @@ -849,8 +860,8 @@ int QuicSocket::SendPacket( return 0; Debug(this, "Sending to %s at port %d", - dest.GetAddress().c_str(), - dest.GetPort()); + remote_addr.GetAddress().c_str(), + remote_addr.GetPort()); // Remaining Length should never be zero at this point CHECK_GT(buffer->RemainingLength(), 0); @@ -872,7 +883,7 @@ int QuicSocket::SendPacket( } last_created_send_wrap_ = nullptr; - int err = udp_->Send(vec.data(), vec.size(), dest.data()); + int err = udp_->Send(vec.data(), vec.size(), remote_addr.data()); Debug(this, "Advancing read head %" PRIu64 " status = %d", total_length, err); diff --git a/src/node_quic_socket.h b/src/node_quic_socket.h index 3a4eabe2ba..851b8a6789 100644 --- a/src/node_quic_socket.h +++ b/src/node_quic_socket.h @@ -137,7 +137,8 @@ class QuicSocket : public AsyncWrap, void ReportSendError( int error); int SendPacket( - const SocketAddress& dest, + const SocketAddress& local_addr, + const SocketAddress& remote_addr, QuicBuffer* buf, BaseObjectPtr session, const char* diagnostic_label = nullptr); @@ -163,7 +164,8 @@ class QuicSocket : public AsyncWrap, // Implementation for UDPWrapListener uv_buf_t OnAlloc(size_t suggested_size) override; - void OnRecv(ssize_t nread, + void OnRecv(uv_udp_t* handle, + ssize_t nread, const uv_buf_t& buf, const sockaddr* addr, unsigned int flags) override; @@ -190,17 +192,18 @@ class QuicSocket : public AsyncWrap, size_t suggested_size, uv_buf_t* buf); - static void OnRecv( - uv_udp_t* handle, - ssize_t nread, - const uv_buf_t* buf, - const struct sockaddr* addr, - unsigned int flags); + // static void OnRecv( + // uv_udp_t* handle, + // ssize_t nread, + // const uv_buf_t* buf, + // const struct sockaddr* addr, + // unsigned int flags); void Receive( ssize_t nread, AllocatedBuffer buf, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, unsigned int flags); void SendInitialConnectionClose( @@ -230,7 +233,8 @@ class QuicSocket : public AsyncWrap, const QuicCID& scid, ssize_t nread, const uint8_t* data, - const struct sockaddr* addr, + const SocketAddress& local_addr, + const struct sockaddr* remote_addr, unsigned int flags); void IncrementSocketAddressCounter(const SocketAddress& addr); diff --git a/src/node_quic_util.h b/src/node_quic_util.h index e3a2a157fd..30e156229d 100644 --- a/src/node_quic_util.h +++ b/src/node_quic_util.h @@ -115,12 +115,26 @@ struct QuicPath : public ngtcp2_path { &this->local, local->data(), local->GetLength(), - nullptr); + local); ngtcp2_addr_init( &this->remote, local->data(), remote->GetLength(), - nullptr); + remote); + } + QuicPath( + const SocketAddress& local, + const SocketAddress& remote) { + ngtcp2_addr_init( + &this->local, + local.data(), + local.GetLength(), + const_cast(&local)); + ngtcp2_addr_init( + &this->remote, + remote.data(), + remote.GetLength(), + const_cast(&remote)); } }; diff --git a/src/udp_wrap.cc b/src/udp_wrap.cc index 20810875ee..55e35eb9da 100644 --- a/src/udp_wrap.cc +++ b/src/udp_wrap.cc @@ -702,10 +702,11 @@ void UDPWrap::OnRecv(uv_udp_t* handle, const sockaddr* addr, unsigned int flags) { UDPWrap* wrap = ContainerOf(&UDPWrap::handle_, handle); - wrap->listener()->OnRecv(nread, *buf, addr, flags); + wrap->listener()->OnRecv(handle, nread, *buf, addr, flags); } -void UDPWrap::OnRecv(ssize_t nread, +void UDPWrap::OnRecv(uv_udp_t* handle, + ssize_t nread, const uv_buf_t& buf_, const sockaddr* addr, unsigned int flags) { diff --git a/src/udp_wrap.h b/src/udp_wrap.h index fe2b9885a3..204d02ec45 100644 --- a/src/udp_wrap.h +++ b/src/udp_wrap.h @@ -46,7 +46,8 @@ class UDPListener { // Called right after data is received from the socket, and includes // information about the source address. If `nread` is negative, an error // has occurred, and it represents a libuv error code. - virtual void OnRecv(ssize_t nread, + virtual void OnRecv(uv_udp_t* handle, + ssize_t nread, const uv_buf_t& buf, const sockaddr* addr, unsigned int flags) = 0; @@ -155,7 +156,8 @@ class UDPWrap final : public HandleWrap, // UDPListener implementation uv_buf_t OnAlloc(size_t suggested_size) override; - void OnRecv(ssize_t nread, + void OnRecv(uv_udp_t* handle, + ssize_t nread, const uv_buf_t& buf, const sockaddr* addr, unsigned int flags) override;