diff --git a/pdns/tcpiohandler.cc b/pdns/tcpiohandler.cc index 72c149b79adb..d13e3be149de 100644 --- a/pdns/tcpiohandler.cc +++ b/pdns/tcpiohandler.cc @@ -75,11 +75,13 @@ class OpenSSLSession : public TLSSession std::unique_ptr d_sess; }; +class OpenSSLTLSIOCtx; + class OpenSSLTLSConnection: public TLSConnection { public: /* server side connection */ - OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr feContext): d_feContext(feContext), d_conn(std::unique_ptr(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout) + OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr tlsCtx, std::unique_ptr&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout) { d_socket = socket; @@ -99,7 +101,7 @@ class OpenSSLTLSConnection: public TLSConnection } /* client-side connection */ - OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout) + OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr tlsCtx, std::unique_ptr&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true) { d_socket = socket; @@ -286,7 +288,7 @@ class OpenSSLTLSConnection: public TLSConnection IOState tryHandshake() override { - if (!d_feContext) { + if (isClient()) { /* In client mode, the handshake is initiated by the call to SSL_connect() done from connect()/tryConnect(). In blocking mode it does not return before the handshake has been finished, @@ -314,7 +316,7 @@ class OpenSSLTLSConnection: public TLSConnection void doHandshake() override { - if (!d_feContext) { + if (isClient()) { /* we are a client, nothing to do, see the non-blocking version */ return; } @@ -335,7 +337,7 @@ class OpenSSLTLSConnection: public TLSConnection IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override { - if (!d_feContext && !d_connected) { + if (isClient() && !d_connected) { if (d_ktls) { /* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */ SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get())); @@ -563,6 +565,11 @@ class OpenSSLTLSConnection: public TLSConnection d_ktls = true; } + [[nodiscard]] bool isClient() const + { + return d_isClient; + } + static void generateConnectionIndexIfNeeded() { auto init = s_initTLSConnIndex.lock(); @@ -588,31 +595,44 @@ class OpenSSLTLSConnection: public TLSConnection static LockGuarded s_initTLSConnIndex; static int s_tlsConnIndex; std::vector> d_tlsSessions; - /* server context */ - std::shared_ptr d_feContext; - /* client context */ - std::shared_ptr d_tlsCtx; + std::shared_ptr d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime std::unique_ptr d_conn; std::string d_hostname; struct timeval d_timeout; bool d_connected{false}; bool d_ktls{false}; + bool d_isClient{false}; }; LockGuarded OpenSSLTLSConnection::s_initTLSConnIndex{false}; int OpenSSLTLSConnection::s_tlsConnIndex{-1}; -class OpenSSLTLSIOCtx: public TLSCtx +class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this { + struct Private + { + explicit Private() = default; + }; + public: + static std::shared_ptr createServerSideContext(TLSFrontend& frontend) + { + return std::make_shared(frontend, Private()); + } + + static std::shared_ptr createClientSideContext(const TLSContextParameters& params) + { + return std::make_shared(params, Private()); + } + /* server side context */ - OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared(fe.d_addr, fe.d_tlsConfig)) + OpenSSLTLSIOCtx(TLSFrontend& frontend, [[maybe_unused]] Private priv): d_feContext(std::make_unique(frontend.d_addr, frontend.d_tlsConfig)) { OpenSSLTLSConnection::generateConnectionIndexIfNeeded(); - d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay; + d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay; - if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) { + if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) { /* use our own ticket keys handler so we can rotate them */ #if OPENSSL_VERSION_MAJOR >= 3 SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb); @@ -629,18 +649,18 @@ class OpenSSLTLSIOCtx: public TLSCtx } #endif /* DISABLE_OCSP_STAPLING */ - libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters); + libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters); - if (!fe.d_tlsConfig.d_keyLogFile.empty()) { - d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile); + if (!frontend.d_tlsConfig.d_keyLogFile.empty()) { + d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, frontend.d_tlsConfig.d_keyLogFile); } try { - if (fe.d_tlsConfig.d_ticketKeyFile.empty()) { + if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) { handleTicketsKeyRotation(time(nullptr)); } else { - OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile); + OpenSSLTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile); } } catch (const std::exception& e) { @@ -649,7 +669,7 @@ class OpenSSLTLSIOCtx: public TLSCtx } /* client side context */ - OpenSSLTLSIOCtx(const TLSContextParameters& params) + OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private priv) { int sslOptions = SSL_OP_NO_SSLv2 | @@ -740,6 +760,11 @@ class OpenSSLTLSIOCtx: public TLSCtx #endif } + OpenSSLTLSIOCtx(const OpenSSLTLSIOCtx&) = delete; + OpenSSLTLSIOCtx(OpenSSLTLSIOCtx&&) = delete; + OpenSSLTLSIOCtx& operator=(const OpenSSLTLSIOCtx&) = delete; + OpenSSLTLSIOCtx& operator=(OpenSSLTLSIOCtx&&) = delete; + ~OpenSSLTLSIOCtx() override { d_tlsCtx.reset(); @@ -797,16 +822,24 @@ class OpenSSLTLSIOCtx: public TLSCtx return 1; } + SSL_CTX* getOpenSSLContext() const + { + if (d_feContext) { + return d_feContext->d_tlsCtx.get(); + } + return d_tlsCtx.get(); + } + std::unique_ptr getConnection(int socket, const struct timeval& timeout, time_t now) override { handleTicketsKeyRotation(now); - return std::make_unique(socket, timeout, d_feContext); + return std::make_unique(socket, timeout, shared_from_this(), std::unique_ptr(SSL_new(getOpenSSLContext()), SSL_free)); } std::unique_ptr getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override { - auto conn = std::make_unique(host, hostIsAddr, socket, timeout, d_tlsCtx); + auto conn = std::make_unique(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr(SSL_new(getOpenSSLContext()), SSL_free)); if (d_ktls) { conn->enableKTLS(); } @@ -841,24 +874,32 @@ class OpenSSLTLSIOCtx: public TLSCtx return "openssl"; } + bool isServerContext() const + { + return d_feContext != nullptr; + } + bool setALPNProtos(const std::vector>& protos) override { - if (d_feContext && d_feContext->d_tlsCtx) { + auto* openSSLContext = getOpenSSLContext(); + if (openSSLContext == nullptr) { + return false; + } + + if (isServerContext()) { d_alpnProtos = protos; - libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this); + libssl_set_alpn_select_callback(openSSLContext, alpnServerSelectCallback, this); return true; } - if (d_tlsCtx) { - return libssl_set_alpn_protos(d_tlsCtx.get(), protos); - } - return false; + + return libssl_set_alpn_protos(openSSLContext, protos); } #ifndef DISABLE_NPN bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override { d_nextProtocolSelectCallback = cb; - libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this); + libssl_set_npn_select_callback(getOpenSSLContext(), npnSelectCallback, this); return true; } #endif /* DISABLE_NPN */ @@ -910,8 +951,8 @@ class OpenSSLTLSIOCtx: public TLSCtx } std::vector> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent - std::shared_ptr d_feContext{nullptr}; std::shared_ptr d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx + std::unique_ptr d_feContext{nullptr}; bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr}; bool d_ktls{false}; }; @@ -1840,7 +1881,7 @@ bool TLSFrontend::setupTLS() #endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL if (d_provider == "openssl") { - newCtx = std::make_shared(*this); + newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this); setupDoTProtocolNegotiation(newCtx); std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release); return true; @@ -1848,7 +1889,7 @@ bool TLSFrontend::setupTLS() #endif /* HAVE_LIBSSL */ } #ifdef HAVE_LIBSSL - newCtx = std::make_shared(*this); + newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this); #else /* HAVE_LIBSSL */ #ifdef HAVE_GNUTLS newCtx = std::make_shared(*this); @@ -1873,13 +1914,13 @@ std::shared_ptr getTLSContext(const TLSContextParameters& params) #endif /* HAVE_GNUTLS */ #ifdef HAVE_LIBSSL if (params.d_provider == "openssl") { - return std::make_shared(params); + return OpenSSLTLSIOCtx::createClientSideContext(params); } #endif /* HAVE_LIBSSL */ } #ifdef HAVE_LIBSSL - return std::make_shared(params); + return OpenSSLTLSIOCtx::createClientSideContext(params); #else /* HAVE_LIBSSL */ #ifdef HAVE_GNUTLS return std::make_shared(params);