Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an endpoint to check connectivity to WOPI server #9202

Merged
merged 4 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 70 additions & 28 deletions net/HttpRequest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ class Session final : public ProtocolHandlerInterface
void setFinishedHandler(FinishedCallback onFinished) { _onFinished = std::move(onFinished); }

/// The onConnectFail callback handler signature.
using ConnectFailCallback = std::function<void()>;
using ConnectFailCallback = std::function<void(const std::shared_ptr<Session>& session)>;

void setConnectFailHandler(ConnectFailCallback onConnectFail) { _onConnectFail = std::move(onConnectFail); }

Expand Down Expand Up @@ -1380,6 +1380,11 @@ class Session final : public ProtocolHandlerInterface
}
}

net::AsyncConnectResult connectionResult()
{
return _result;
}

/// Returns the socket FD, for logging/informational purposes.
int getFD() const { return _fd; }

Expand Down Expand Up @@ -1448,6 +1453,30 @@ class Session final : public ProtocolHandlerInterface
return _response->state() == Response::State::Complete;
}

void callOnFinished()
{
if (!_onFinished)
return;

LOG_TRC("onFinished calling client");
std::shared_ptr<Session> self = shared_from_this();
try
{
[[maybe_unused]] const long references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onFinished(self);

assert(self.use_count() > 1 &&
"Erroneously onFinish reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onFinished client callback: " << exc.what());
}
}

/// Set up a new request and response.
void newRequest(const Request& req)
{
Expand All @@ -1468,26 +1497,8 @@ class Session final : public ProtocolHandlerInterface
assert(_response->state() != Response::State::Incomplete &&
"Unexpected response in Incomplete state");
assert(_response->done() && "Must have response in done state");
if (_onFinished)
{
LOG_TRC("onFinished calling client");
auto self = shared_from_this();
try
{
[[maybe_unused]] const auto references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onFinished(std::static_pointer_cast<Session>(self));

assert(self.use_count() > 1 &&
"Erroneously onFinish reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onFinished client callback: " << exc.what());
}
}
callOnFinished();

if (_response->header().getConnectionToken() == Header::ConnectionToken::Close)
{
Expand Down Expand Up @@ -1609,11 +1620,40 @@ class Session final : public ProtocolHandlerInterface

if (!socket->send(_request))
{
_result = net::AsyncConnectResult::SocketError;
LOG_ERR("Error while writing to socket");
}
}
}

std::shared_ptr<Session> shared_from_this()
{
return std::static_pointer_cast<Session>(ProtocolHandlerInterface::shared_from_this());
}

void callOnConnectFail()
{
if (!_onConnectFail)
return;

std::shared_ptr<Session> self = shared_from_this();
try
{
[[maybe_unused]] const long references = self.use_count();
assert(references > 1 && "Expected more than 1 reference to http::Session.");

_onConnectFail(self);

assert(self.use_count() > 1 &&
"Erroneously onConnectFail reset 'this'. Use 'addCallback()' on the "
"SocketPoll to reset on idle instead.");
}
catch (const std::exception& exc)
{
LOG_ERR("Error while invoking onConnectFail client callback: " << exc.what());
}
}

// on failure the stream will be discarded, so save the ssl verification
// result while it is still available
void onHandshakeFail() override
Expand All @@ -1623,7 +1663,10 @@ class Session final : public ProtocolHandlerInterface
{
LOG_TRC("onHandshakeFail");
_handshakeSslVerifyFailure = socket->getSslVerifyResult();
_result = net::AsyncConnectResult::SSLHandShakeFailure;
}

callOnConnectFail();
}

void onDisconnect() override
Expand Down Expand Up @@ -1659,22 +1702,20 @@ class Session final : public ProtocolHandlerInterface
return socket; // Return the shared pointer.
}

void asyncConnectCompleted(SocketPoll& poll, std::shared_ptr<StreamSocket> socket)
void asyncConnectCompleted(SocketPoll& poll, const std::shared_ptr<StreamSocket> &socket, net::AsyncConnectResult result)
{
assert((!socket || _fd == socket->getFD()) &&
"The socket FD must have been set in onConnect");

// When used with proxy.php we may indeed get nullptr here.
// assert(socket && "Unexpected nullptr returned from net::connect");
_socket = socket; // Hold a weak pointer to it.
_result = result;

if (!socket)
{
LOG_ERR("Failed to connect to " << _host << ':' << _port);

if (_onConnectFail)
_onConnectFail();

callOnConnectFail();
return;
}

Expand All @@ -1689,9 +1730,9 @@ class Session final : public ProtocolHandlerInterface
{
_socket.reset(); // Reset to make sure we are disconnected.

auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr<StreamSocket> socket) {
poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket)]() {
asyncConnectCompleted(poll, socket);
auto pushConnectCompleteToPoll = [this, &poll](std::shared_ptr<StreamSocket> socket, net::AsyncConnectResult result ) {
poll.addCallback([selfLifecycle = shared_from_this(), this, &poll, socket=std::move(socket), &result]() {
asyncConnectCompleted(poll, socket, result);
});
};

Expand Down Expand Up @@ -1744,6 +1785,7 @@ class Session final : public ProtocolHandlerInterface
ConnectFailCallback _onConnectFail;
std::shared_ptr<Response> _response;
std::weak_ptr<StreamSocket> _socket; ///< Must be the last member.
net::AsyncConnectResult _result; // last connection tentative result
};

/// HTTP Get a URL synchronously.
Expand Down
19 changes: 13 additions & 6 deletions net/NetUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
if (host.empty() || port.empty())
{
LOG_ERR("Invalid host/port " << host << ':' << port);
asyncCb(nullptr);
asyncCb(nullptr, AsyncConnectResult::HostNameError);
return;
}

Expand All @@ -389,7 +389,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
if (isSSL)
{
LOG_ERR("Error: isSSL socket requested but SSL is not compiled in.");
asyncCb(nullptr);
asyncCb(nullptr, asyncConnectResult::MissingSSLError);
return;
}
#endif
Expand All @@ -399,6 +399,8 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
{
std::shared_ptr<StreamSocket> socket;

AsyncConnectResult result = AsyncConnectResult::UnknownHostError;

if (const addrinfo* ainfo = hostEntry.getAddrInfo())
{
for (const addrinfo* ai = ainfo; ai; ai = ai->ai_next)
Expand All @@ -408,13 +410,15 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
int fd = ::socket(ai->ai_addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0);
if (fd < 0)
{
result = AsyncConnectResult::SocketError;
LOG_SYS("Failed to create socket");
continue;
}

int res = ::connect(fd, ai->ai_addr, ai->ai_addrlen);
if (res < 0 && errno != EINPROGRESS)
{
result = AsyncConnectResult::ConnectionError;
LOG_SYS("Failed to connect to " << host);
::close(fd);
}
Expand All @@ -439,9 +443,12 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
{
LOG_DBG('#' << fd << " New socket connected to " << host << ':' << port
<< " (" << (isSSL ? "SSL)" : "Unencrypted)"));
result = AsyncConnectResult::Ok;
break;
}

result = AsyncConnectResult::SocketError;

LOG_ERR("Failed to allocate socket for client websocket " << host);
::close(fd);
break;
Expand All @@ -452,7 +459,7 @@ asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
else
LOG_SYS("Failed to lookup host [" << host << "]. Skipping");

asyncCb(std::move(socket));
asyncCb(std::move(socket), result);
};

net::AsyncDNS::DNSThreadDumpStateFn dumpState = [host, port]() -> std::string
Expand Down Expand Up @@ -570,7 +577,7 @@ connect(std::string uri, const std::shared_ptr<ProtocolHandlerInterface>& protoc
}

bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port,
std::string& url)
std::string& pathAndQuery)
{
const auto itScheme = uri.find("://");
if (itScheme != uri.npos)
Expand All @@ -587,12 +594,12 @@ bool parseUri(std::string uri, std::string& scheme, std::string& host, std::stri
const auto itUrl = uri.find('/');
if (itUrl != uri.npos)
{
url = uri.substr(itUrl); // Including the first foreslash.
pathAndQuery = uri.substr(itUrl); // Including the first foreslash.
uri = uri.substr(0, itUrl);
}
else
{
url.clear();
pathAndQuery.clear();
}

const auto itPort = uri.find(':');
Expand Down
17 changes: 13 additions & 4 deletions net/NetUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,16 @@ std::shared_ptr<StreamSocket>
connect(const std::string& host, const std::string& port, const bool isSSL,
const std::shared_ptr<ProtocolHandlerInterface>& protocolHandler);

typedef std::function<void(std::shared_ptr<StreamSocket>)> asyncConnectCB;
enum class AsyncConnectResult{
Ok = 0,
SocketError,
ConnectionError,
HostNameError,
UnknownHostError,
SSLHandShakeFailure,
};

typedef std::function<void(std::shared_ptr<StreamSocket>, AsyncConnectResult result)> asyncConnectCB;

void
asyncConnect(const std::string& host, const std::string& port, const bool isSSL,
Expand All @@ -103,14 +112,14 @@ connect(std::string uri, const std::shared_ptr<ProtocolHandlerInterface>& protoc
/// Decomposes a URI into its components.
/// Returns true if parsing was successful.
bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port,
std::string& url);
std::string& pathAndQuery);

/// Decomposes a URI into its components.
/// Returns true if parsing was successful.
inline bool parseUri(std::string uri, std::string& scheme, std::string& host, std::string& port)
{
std::string url;
return parseUri(std::move(uri), scheme, host, port, url);
std::string pathAndQuery;
return parseUri(std::move(uri), scheme, host, port, pathAndQuery);
}

/// Return the locator given a URI.
Expand Down
6 changes: 3 additions & 3 deletions test/HttpRequestTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ void HttpRequestTests::testSimpleGet()

std::unique_lock<std::mutex> lock(mutex);

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down Expand Up @@ -535,7 +535,7 @@ void HttpRequestTests::test500GetStatuses()
std::unique_lock<std::mutex> lock(mutex);
timedout = true; // Assume we timed out until we prove otherwise.

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down Expand Up @@ -628,7 +628,7 @@ void HttpRequestTests::testSimplePost_External()

std::unique_lock<std::mutex> lock(mutex);

httpSession->setConnectFailHandler([]() {
httpSession->setConnectFailHandler([](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down
2 changes: 1 addition & 1 deletion test/UnitProxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class UnitProxy : public UnitWSD
// Request from rating.collaboraonline.com.
_req = http::Request("/browser/a90f83c/foo/remote/static/lokit-extra-img.svg");

httpSession->setConnectFailHandler([this]() {
httpSession->setConnectFailHandler([this](const std::shared_ptr<http::Session>&) {
LOK_ASSERT_FAIL("Unexpected connection failure");
});

Expand Down
Loading
Loading