Skip to content

Commit

Permalink
MDClient: Keep error and RBAC data when handling responses
Browse files Browse the repository at this point in the history
Make sure all data (notably the error message) is kept when reading
responses.
  • Loading branch information
frankosterfeld authored and RalphSteinhagen committed Nov 7, 2023
1 parent be867a6 commit a245ee1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 25 deletions.
43 changes: 19 additions & 24 deletions src/client/include/Client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,16 @@ class Client : public MDClientBase {
return true;
}

static bool handleMessage(const mdp::Message &message, mdp::Message &output) {
// subscription updates
static bool handleMessage(mdp::Message &message) {
if (message.command == mdp::Command::Notify || message.command == mdp::Command::Final) {
output.arrivalTime = std::chrono::system_clock::now();
const auto body = message.data.asString();
output.data.resize(body.size());
std::memcpy(output.data.data(), body.begin(), body.size());
output.endpoint = message.endpoint;
auto params = output.endpoint.queryParamMap();
auto requestId_sv = message.clientRequestID.asString();
if (auto result = std::from_chars(requestId_sv.data(), requestId_sv.data() + requestId_sv.size(), output.id); result.ec == std::errc::invalid_argument || result.ec == std::errc::result_out_of_range) {
output.id = 0;
message.arrivalTime = std::chrono::system_clock::now();
const auto requestId_sv = message.clientRequestID.asString();
if (auto result = std::from_chars(requestId_sv.data(), requestId_sv.data() + requestId_sv.size(), message.id); result.ec == std::errc::invalid_argument || result.ec == std::errc::result_out_of_range) {
message.id = 0;
}
return true;
}
return true;
return false;
}

bool receive(mdp::Message &msg) override {
Expand All @@ -156,7 +150,10 @@ class Client : public MDClientBase {
continue;
}
if (auto message = zmq::receive<mdp::MessageFormat::WithoutSourceId>(con._socket)) {
return handleMessage(std::move(*message), msg);
if (!handleMessage(*message))
return false;
msg = std::move(*message);
return true;
}
}
return false;
Expand Down Expand Up @@ -275,15 +272,17 @@ class SubscriptionClient : public MDClientBase {
return true;
}

static bool handleMessage(const mdp::BasicMessage<MessageFormat::WithSourceId> &message, mdp::Message &output) {
static bool handleMessage(mdp::BasicMessage<MessageFormat::WithSourceId> &&message, mdp::Message &output) {
// subscription updates
if (message.command == mdp::Command::Notify || message.command == mdp::Command::Final) {
output.arrivalTime = std::chrono::system_clock::now();
output.data = message.data;
output.data = std::move(message.data);
output.error = std::move(message.error);
// output.serviceName = URI<uri_check::STRICT>(std::string{ message.serviceName() });
output.serviceName = message.sourceId; // temporary hack until serviceName -> 'requestedTopic' and 'topic' -> 'replyTopic'
output.endpoint = message.endpoint;
output.clientRequestID = message.clientRequestID;
output.serviceName = std::move(message.sourceId); // temporary hack until serviceName -> 'requestedTopic' and 'topic' -> 'replyTopic'
output.endpoint = std::move(message.endpoint);
output.clientRequestID = std::move(message.clientRequestID);
output.rbac = std::move(message.rbac);
output.id = 0; // review if this is still needed
return true;
}
Expand All @@ -295,12 +294,8 @@ class SubscriptionClient : public MDClientBase {
if (con._connectionState != detail::ConnectionState::CONNECTED) {
continue;
}
while (true) {
if (auto message = zmq::receive<MessageFormat::WithSourceId>(con._socket)) {
return handleMessage(*message, msg);
} else {
break;
}
if (auto message = zmq::receive<MessageFormat::WithSourceId>(con._socket)) {
return handleMessage(std::move(*message), msg);
}
}
return false;
Expand Down
7 changes: 6 additions & 1 deletion src/client/test/ClientPublisher_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,20 @@ TEST_CASE("Basic get/set test", "[ClientContext]") {
auto endpoint = URI<STRICT>::factory(URI<STRICT>(server.address())).scheme("mdp").path("/a.service").addQueryParameter("C", "2").build();
std::atomic<int> received{ 0 };
clientContext.get(endpoint, [&received](const Message &message) {
REQUIRE(message.data.size() == 3); // == "100");
REQUIRE(message.data.asString() == "100");
REQUIRE(message.error == "404");
REQUIRE(message.rbac.asString() == "rbac");
received++;
});
std::this_thread::sleep_for(20ms); // allow the request to reach the server
server.processRequest([&endpoint](auto &&req, auto &reply) {
REQUIRE(req.command == Command::Get);
reply.data = IoBuffer("100");
reply.error = "404";
reply.rbac = IoBuffer("rbac");
reply.endpoint = Message::URI::factory(endpoint).addQueryParameter("ctx", "test_ctx").build();
});

std::this_thread::sleep_for(20ms); // hacky: this is needed because the requests are only identified using their uri, so we cannot have multiple requests with identical uris
auto testData = std::vector<std::byte>{ std::byte{ 'a' }, std::byte{ 'b' }, std::byte{ 'c' } };
opencmw::IoBuffer dataSetRequest;
Expand Down

0 comments on commit a245ee1

Please sign in to comment.