Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dnsdist-1.9.x: Backport 14671 - Always store the OpenSSLTLSIOCtx in the connection #14677

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 75 additions & 34 deletions pdns/tcpiohandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,13 @@ class OpenSSLSession : public TLSSession
std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
};

class OpenSSLTLSIOCtx;

class OpenSSLTLSConnection: public TLSConnection
{
public:
/* server side connection */
OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(std::move(feContext)), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
{
d_socket = socket;

Expand All @@ -97,7 +99,7 @@ class OpenSSLTLSConnection: public TLSConnection
}

/* client-side connection */
OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
OpenSSLTLSConnection(std::string hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
{
d_socket = socket;

Expand Down Expand Up @@ -284,7 +286,7 @@ class OpenSSLTLSConnection: public TLSConnection

IOState tryHandshake() override
{
if (!d_feContext) {
if (isClient()) {
/* In client mode, the handshake is initiated by the call to SSL_connect()
done from connect()/tryConnect().
In blocking mode it does not return before the handshake has been finished,
Expand Down Expand Up @@ -312,7 +314,7 @@ class OpenSSLTLSConnection: public TLSConnection

void doHandshake() override
{
if (!d_feContext) {
if (isClient()) {
/* we are a client, nothing to do, see the non-blocking version */
return;
}
Expand All @@ -333,7 +335,7 @@ class OpenSSLTLSConnection: public TLSConnection

IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
{
if (!d_feContext && !d_connected) {
if (isClient() && !d_connected) {
if (d_ktls) {
/* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
Expand Down Expand Up @@ -552,6 +554,11 @@ class OpenSSLTLSConnection: public TLSConnection
d_ktls = true;
}

[[nodiscard]] bool isClient() const
{
return d_isClient;
}

static void generateConnectionIndexIfNeeded()
{
auto init = s_initTLSConnIndex.lock();
Expand All @@ -577,31 +584,44 @@ class OpenSSLTLSConnection: public TLSConnection
static LockGuarded<bool> s_initTLSConnIndex;
static int s_tlsConnIndex;
std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
/* server context */
std::shared_ptr<OpenSSLFrontendContext> d_feContext;
/* client context */
std::shared_ptr<SSL_CTX> d_tlsCtx;
std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
std::string d_hostname;
struct timeval d_timeout;
bool d_connected{false};
bool d_ktls{false};
bool d_isClient{false};
};

LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
int OpenSSLTLSConnection::s_tlsConnIndex{-1};

class OpenSSLTLSIOCtx: public TLSCtx
class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
{
struct Private
{
explicit Private() = default;
};

public:
static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& frontend)
{
return std::make_shared<OpenSSLTLSIOCtx>(frontend, Private());
}

static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
{
return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
}

/* server side context */
OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
OpenSSLTLSIOCtx(TLSFrontend& frontend, [[maybe_unused]] Private priv): d_feContext(std::make_unique<OpenSSLFrontendContext>(frontend.d_addr, frontend.d_tlsConfig))
{
OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

d_ticketsKeyRotationDelay = fe.d_tlsConfig.d_ticketsKeyRotationDelay;
d_ticketsKeyRotationDelay = frontend.d_tlsConfig.d_ticketsKeyRotationDelay;

if (fe.d_tlsConfig.d_enableTickets && fe.d_tlsConfig.d_numberOfTicketsKeys > 0) {
if (frontend.d_tlsConfig.d_enableTickets && frontend.d_tlsConfig.d_numberOfTicketsKeys > 0) {
/* use our own ticket keys handler so we can rotate them */
#if OPENSSL_VERSION_MAJOR >= 3
SSL_CTX_set_tlsext_ticket_key_evp_cb(d_feContext->d_tlsCtx.get(), &OpenSSLTLSIOCtx::ticketKeyCb);
Expand All @@ -618,22 +638,22 @@ class OpenSSLTLSIOCtx: public TLSCtx
}
#endif /* DISABLE_OCSP_STAPLING */

if (fe.d_tlsConfig.d_readAhead) {
if (frontend.d_tlsConfig.d_readAhead) {
SSL_CTX_set_read_ahead(d_feContext->d_tlsCtx.get(), 1);
}

libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &fe.d_tlsCounters);
libssl_set_error_counters_callback(d_feContext->d_tlsCtx, &frontend.d_tlsCounters);

if (!fe.d_tlsConfig.d_keyLogFile.empty()) {
d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, fe.d_tlsConfig.d_keyLogFile);
if (!frontend.d_tlsConfig.d_keyLogFile.empty()) {
d_feContext->d_keyLogFile = libssl_set_key_log_file(d_feContext->d_tlsCtx, frontend.d_tlsConfig.d_keyLogFile);
}

try {
if (fe.d_tlsConfig.d_ticketKeyFile.empty()) {
if (frontend.d_tlsConfig.d_ticketKeyFile.empty()) {
handleTicketsKeyRotation(time(nullptr));
}
else {
OpenSSLTLSIOCtx::loadTicketsKeys(fe.d_tlsConfig.d_ticketKeyFile);
OpenSSLTLSIOCtx::loadTicketsKeys(frontend.d_tlsConfig.d_ticketKeyFile);
}
}
catch (const std::exception& e) {
Expand All @@ -642,7 +662,7 @@ class OpenSSLTLSIOCtx: public TLSCtx
}

/* client side context */
OpenSSLTLSIOCtx(const TLSContextParameters& params)
OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private priv)
{
int sslOptions =
SSL_OP_NO_SSLv2 |
Expand Down Expand Up @@ -733,6 +753,11 @@ class OpenSSLTLSIOCtx: public TLSCtx
#endif
}

OpenSSLTLSIOCtx(const OpenSSLTLSIOCtx&) = delete;
OpenSSLTLSIOCtx(OpenSSLTLSIOCtx&&) = delete;
OpenSSLTLSIOCtx& operator=(const OpenSSLTLSIOCtx&) = delete;
OpenSSLTLSIOCtx& operator=(OpenSSLTLSIOCtx&&) = delete;

~OpenSSLTLSIOCtx() override
{
d_tlsCtx.reset();
Expand Down Expand Up @@ -790,16 +815,24 @@ class OpenSSLTLSIOCtx: public TLSCtx
return 1;
}

SSL_CTX* getOpenSSLContext() const
{
if (d_feContext) {
return d_feContext->d_tlsCtx.get();
}
return d_tlsCtx.get();
}

std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
{
handleTicketsKeyRotation(now);

return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
}

std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
{
auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, d_tlsCtx);
auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
if (d_ktls) {
conn->enableKTLS();
}
Expand Down Expand Up @@ -834,24 +867,32 @@ class OpenSSLTLSIOCtx: public TLSCtx
return "openssl";
}

bool isServerContext() const
{
return d_feContext != nullptr;
}

bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
{
if (d_feContext && d_feContext->d_tlsCtx) {
auto* openSSLContext = getOpenSSLContext();
if (openSSLContext == nullptr) {
return false;
}

if (isServerContext()) {
d_alpnProtos = protos;
libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
libssl_set_alpn_select_callback(openSSLContext, alpnServerSelectCallback, this);
return true;
}
if (d_tlsCtx) {
return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
}
return false;

return libssl_set_alpn_protos(openSSLContext, protos);
}

#ifndef DISABLE_NPN
bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
{
d_nextProtocolSelectCallback = cb;
libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
libssl_set_npn_select_callback(getOpenSSLContext(), npnSelectCallback, this);
return true;
}
#endif /* DISABLE_NPN */
Expand Down Expand Up @@ -906,8 +947,8 @@ class OpenSSLTLSIOCtx: public TLSCtx
}

std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
bool d_ktls{false};
};
Expand Down Expand Up @@ -1857,13 +1898,13 @@ bool TLSFrontend::setupTLS()
#endif /* HAVE_GNUTLS */
#if defined(HAVE_LIBSSL)
if (d_provider == "openssl") {
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
}
#endif /* HAVE_LIBSSL */

if (!newCtx) {
#if defined(HAVE_LIBSSL)
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
#elif defined(HAVE_GNUTLS)
newCtx = std::make_shared<GnuTLSIOCtx>(*this);
#else
Expand Down Expand Up @@ -1895,13 +1936,13 @@ std::shared_ptr<TLSCtx> getTLSContext([[maybe_unused]] const TLSContextParameter
#endif /* HAVE_GNUTLS */
#if defined(HAVE_LIBSSL)
if (params.d_provider == "openssl") {
return std::make_shared<OpenSSLTLSIOCtx>(params);
return OpenSSLTLSIOCtx::createClientSideContext(params);
}
#endif /* HAVE_LIBSSL */
}

#if defined(HAVE_LIBSSL)
return std::make_shared<OpenSSLTLSIOCtx>(params);
return OpenSSLTLSIOCtx::createClientSideContext(params);
#elif defined(HAVE_GNUTLS)
return std::make_shared<GnuTLSIOCtx>(params);
#else
Expand Down