From 2969e70491d77d6addc5201bf556f6197a2b9f94 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 23 Jul 2023 22:54:07 +0300 Subject: [PATCH 01/32] Refactor jsonrpc --- include/jsonrpc.hpp | 22 +++-- src/jsonrpc.cpp | 218 +++++++++++++++++++++++++++++--------------- tests/main.cpp | 80 ++++++++++------ 3 files changed, 208 insertions(+), 112 deletions(-) diff --git a/include/jsonrpc.hpp b/include/jsonrpc.hpp index f8d8b4d..5b3f029 100644 --- a/include/jsonrpc.hpp +++ b/include/jsonrpc.hpp @@ -2,7 +2,7 @@ // jsonrpc.hpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #pragma once @@ -22,6 +22,7 @@ class JsonRPC using OutputCallbackFunc = std::function; public: + // clang-format off enum class ErrorCode : int { ///@{ @@ -30,12 +31,12 @@ class JsonRPC InvalidRequest = -32600, ///< Invalid Request The JSON sent is not a valid Request object. MethodNotFound = -32601, ///< Method not found The method does not exist / is not available. InvalidParams = -32602, ///< Invalid params Invalid method parameter(s). - InternalError = - -32603, ///< Internal error Internal JSON-RPC error. - // -32000 to -32099 Server error Reserved for implementation-defined server-errors. + InternalError = -32603, ///< Internal error Internal JSON-RPC error. + // -32000 to -32099 Server error Reserved for implementation-defined server-errors. NotInitialized = -32002 ///< The first client's message is not equal to "initialize" - ///@} + ///@} }; + // clang-format on friend std::ostream& operator<<(std::ostream& out, ErrorCode const& code) { @@ -65,16 +66,25 @@ class JsonRPC /** Send trace message to client. */ - void LogTrace(const std::string& message, const std::string& verbose = ""); + void WriteTrace(const std::string& message, const std::string& verbose); void WriteError(JsonRPC::ErrorCode errorCode, const std::string& message) const; private: + void ProcessBufferContent(); + void ProcessMethod(); + void ProcessBufferHeader(); + void OnInitialize(); void OnTracingChanged(const nlohmann::json& data); bool ReadHeader(); void FireMethodCallback(); void FireRespondCallback(); + void LogBufferContent() const; + void LogMessage(const std::string& message) const; + void LogAndHandleParseError(std::exception& e); + void LogAndHandleUnexpectedMessage(); + private: std::string m_method; std::string m_buffer; diff --git a/src/jsonrpc.cpp b/src/jsonrpc.cpp index 7de3a57..fdcbcbb 100644 --- a/src/jsonrpc.cpp +++ b/src/jsonrpc.cpp @@ -2,7 +2,7 @@ // jsonrpc.cpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #include "jsonrpc.hpp" @@ -13,8 +13,15 @@ using namespace nlohmann; namespace ocls { namespace { + constexpr char logger[] = "jrpc"; constexpr char LE[] = "\r\n"; + +inline std::string FormatBool(bool flag) +{ + return flag ? "yes" : "no"; +} + } // namespace void JsonRPC::RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func) @@ -38,73 +45,14 @@ void JsonRPC::RegisterOutputCallback(OutputCallbackFunc&& func) void JsonRPC::Consume(char c) { m_buffer += c; + if (m_validHeader) { - if (m_buffer.length() != m_contentLength) - return; - try - { - spdlog::get(logger)->debug(""); - spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); - for (auto& header : m_headers) - spdlog::get(logger)->debug(header.first, ": ", header.second); - spdlog::get(logger)->debug(m_buffer); - spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); - spdlog::get(logger)->debug(""); - - m_body = json::parse(m_buffer); - const auto method = m_body["method"]; - if (method.is_string()) - { - m_method = method.get(); - if (m_method == "initialize") - { - OnInitialize(); - } - else if (!m_initialized) - { - spdlog::get(logger)->error("Unexpected first message: '{}'", m_method); - WriteError(ErrorCode::NotInitialized, "Server was not initialized."); - return; - } - else if (m_method == "$/setTrace") - { - OnTracingChanged(m_body); - } - FireMethodCallback(); - } - else - { - FireRespondCallback(); - } - m_isProcessing = false; - } - catch (std::exception& e) - { - spdlog::get(logger)->error("Failed to parse request with reason: '{}'\n{}", e.what(), "\n", m_buffer); - m_buffer.clear(); - WriteError(ErrorCode::ParseError, "Failed to parse request"); - return; - } + ProcessBufferContent(); } else { - if (ReadHeader()) - m_buffer.clear(); - - if (m_buffer == LE) - { - m_buffer.clear(); - m_validHeader = m_contentLength > 0; - if (m_validHeader) - { - m_buffer.reserve(m_contentLength); - } - else - { - WriteError(ErrorCode::InvalidRequest, "Invalid content length"); - } - } + ProcessBufferHeader(); } } @@ -128,11 +76,7 @@ void JsonRPC::Write(const json& data) const message.append(LE); message.append(content); - spdlog::get(logger)->debug(""); - spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); - spdlog::get(logger)->debug(message); - spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); - spdlog::get(logger)->debug(""); + LogMessage(message); m_outputCallback(message); } @@ -153,7 +97,7 @@ void JsonRPC::Reset() m_isProcessing = true; } -void JsonRPC::LogTrace(const std::string& message, const std::string& verbose) +void JsonRPC::WriteTrace(const std::string& message, const std::string& verbose) { if (!m_tracing) { @@ -166,11 +110,88 @@ void JsonRPC::LogTrace(const std::string& message, const std::string& verbose) { spdlog::get(logger)->debug("JRPC verbose tracing is disabled"); spdlog::get(logger)->trace("The verbose message was: {}", verbose); + } + + // clang-format off + Write(json({ + {"method", "$/logTrace"}, + {"params", { + {"message", message}, + {"verbose", m_verbosity ? verbose : "" + }}} + })); + // clang-format on +} + +// private + +void JsonRPC::ProcessBufferContent() +{ + if (m_buffer.length() != m_contentLength) + { return; } - Write( - json({{"method", "$/logTrace"}, {"params", {{"message", message}, {"verbose", m_verbosity ? verbose : ""}}}})); + try + { + LogBufferContent(); + + m_body = json::parse(m_buffer); + const auto method = m_body["method"]; + if (method.is_string()) + { + m_method = method.get(); + ProcessMethod(); + } + else + { + FireRespondCallback(); + } + m_isProcessing = false; + } + catch (std::exception& e) + { + LogAndHandleParseError(e); + } +} + +void JsonRPC::ProcessMethod() +{ + if (m_method == "initialize") + { + OnInitialize(); + } + else if (!m_initialized) + { + return LogAndHandleUnexpectedMessage(); + } + else if (m_method == "$/setTrace") + { + OnTracingChanged(m_body); + } + FireMethodCallback(); +} + +void JsonRPC::ProcessBufferHeader() +{ + if (ReadHeader()) + { + m_buffer.clear(); + } + + if (m_buffer == LE) + { + m_buffer.clear(); + m_validHeader = m_contentLength > 0; + if (m_validHeader) + { + m_buffer.reserve(m_contentLength); + } + else + { + WriteError(ErrorCode::InvalidRequest, "Invalid content length"); + } + } } void JsonRPC::OnInitialize() @@ -182,7 +203,7 @@ void JsonRPC::OnInitialize() m_verbosity = traceValue == "verbose"; m_initialized = true; spdlog::get(logger)->debug( - "Tracing options: is verbose: {}, is on: {}", m_verbosity ? "yes" : "no", m_tracing ? "yes" : "no"); + "Tracing options: is verbose: {}, is on: {}", FormatBool(m_verbosity), FormatBool(m_tracing)); } catch (std::exception& err) { @@ -198,9 +219,7 @@ void JsonRPC::OnTracingChanged(const json& data) m_tracing = traceValue != "off"; m_verbosity = traceValue == "verbose"; spdlog::get(logger)->debug( - "Tracing options were changed, is verbose: {}, is on: {}", - m_verbosity ? "yes" : "no", - m_tracing ? "yes" : "no"); + "Tracing options were changed, is verbose: {}, is on: {}", FormatBool(m_verbosity), FormatBool(m_tracing)); } catch (std::exception& err) { @@ -219,7 +238,9 @@ bool JsonRPC::ReadHeader() std::string key = match.str(1); std::string value = match.str(2); if (key == "Content-Length") + { m_contentLength = std::stoi(value); + } m_headers[key] = value; ++next; } @@ -244,7 +265,7 @@ void JsonRPC::FireMethodCallback() const bool isRequest = m_body["params"]["id"] != nullptr; const bool mustRespond = isRequest || m_method.rfind("$/", 0) == std::string::npos; spdlog::get(logger)->debug( - "Got request: {}, respond is required: {}", isRequest ? "yes" : "no", mustRespond ? "yes" : "no"); + "Got request: {}, respond is required: {}", FormatBool(isRequest), FormatBool(mustRespond)); if (mustRespond) { WriteError(ErrorCode::MethodNotFound, "Method '" + m_method + "' is not supported."); @@ -276,4 +297,49 @@ void JsonRPC::WriteError(JsonRPC::ErrorCode errorCode, const std::string& messag Write(obj); } +void JsonRPC::LogBufferContent() const +{ + if (spdlog::get_level() > spdlog::level::debug) + { + return; + } + + spdlog::get(logger)->debug(""); + spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); + for (auto& header : m_headers) + { + spdlog::get(logger)->debug(header.first, ": ", header.second); + } + spdlog::get(logger)->debug(m_buffer); + spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); + spdlog::get(logger)->debug(""); +} + +void JsonRPC::LogMessage(const std::string& message) const +{ + if (spdlog::get_level() > spdlog::level::debug) + { + return; + } + + spdlog::get(logger)->debug(""); + spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); + spdlog::get(logger)->debug(message); + spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); + spdlog::get(logger)->debug(""); +} + +void JsonRPC::LogAndHandleParseError(std::exception& e) +{ + spdlog::get(logger)->error("Failed to parse request with reason: '{}'\n{}", e.what(), "\n", m_buffer); + m_buffer.clear(); + WriteError(ErrorCode::ParseError, "Failed to parse request"); +} + +void JsonRPC::LogAndHandleUnexpectedMessage() +{ + spdlog::get(logger)->error("Unexpected first message: '{}'", m_method); + WriteError(ErrorCode::NotInitialized, "Server was not initialized."); +} + } // namespace ocls diff --git a/tests/main.cpp b/tests/main.cpp index 0aca521..ac7b42b 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -51,10 +51,7 @@ std::string ParseResponse(std::string str) } const std::string initRequest = BuildRequest(json::object( - {{"jsonrpc", "2.0"}, - {"id", 0}, - {"method", "initialize"}, - {"params", {{"processId", 60650}, {"trace", "off"}}}})); + {{"jsonrpc", "2.0"}, {"id", 0}, {"method", "initialize"}, {"params", {{"processId", 60650}, {"trace", "off"}}}})); const auto InitializeJsonRPC = [](JsonRPC& jrpc) { jrpc.RegisterOutputCallback([](const std::string&) {}); @@ -67,58 +64,81 @@ const auto InitializeJsonRPC = [](JsonRPC& jrpc) { TEST(JsonRPCTest, InvalidRequestHandling) { + JsonRPC jrpc; + json response; const std::string request = R"!({"jsonrpc: 2.0", "id":0, [method]: "initialize"})!"; const std::string message = BuildRequest(request); - JsonRPC jrpc; - jrpc.RegisterOutputCallback([](const std::string& message) { - auto response = json::parse(ParseResponse(message)); - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::ParseError)); - }); + jrpc.RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + Send(message, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::ParseError)); } TEST(JsonRPCTest, OutOfOrderRequest) { - const std::string message = BuildRequest( - json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - JsonRPC jrpc; - jrpc.RegisterOutputCallback([](const std::string& message) { - auto response = json::parse(ParseResponse(message)); - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::NotInitialized)); - }); + json response; + const std::string message = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc.RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + Send(message, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::NotInitialized)); } TEST(JsonRPCTest, MethodInitialize) { JsonRPC jrpc; + int64_t processId = 0; jrpc.RegisterOutputCallback([](const std::string&) {}); - jrpc.RegisterMethodCallback("initialize", [](const json& request) { - const auto& processId = request["params"]["processId"].get(); - EXPECT_EQ(processId, 60650); - }); + jrpc.RegisterMethodCallback( + "initialize", [&processId](const json& request) { processId = request["params"]["processId"].get(); }); + Send(initRequest, jrpc); + + EXPECT_EQ(processId, 60650); } TEST(JsonRPCTest, RespondToUnsupportedMethod) { JsonRPC jrpc; InitializeJsonRPC(jrpc); + json response; + const std::string request = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc.RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + // send unsupported request - const std::string request = BuildRequest( - json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc.RegisterOutputCallback([](const std::string& message) { - auto response = json::parse(ParseResponse(message)); - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::MethodNotFound)); - }); Send(request, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::MethodNotFound)); +} + +TEST(JsonRPCTest, RespondToSupportedMethod) +{ + JsonRPC jrpc; + InitializeJsonRPC(jrpc); + bool isCallbackCalled = false; + const std::string request = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc.RegisterMethodCallback( + "textDocument/didOpen", [&isCallbackCalled]([[maybe_unused]] const json& request) { isCallbackCalled = true; }); + + Send(request, jrpc); + + EXPECT_TRUE(isCallbackCalled); } -int main(int argc, char **argv) { +int main(int argc, char** argv) +{ ::testing::InitGoogleTest(&argc, argv); auto sink = std::make_shared(); auto mainLogger = std::make_shared("opencl-language-server", sink); From dbaa6946b0e5eef6d635175f479fffec4433fa02 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 23 Jul 2023 22:54:34 +0300 Subject: [PATCH 02/32] Fix clang-format toggle --- src/lsp.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lsp.cpp b/src/lsp.cpp index c41c9c3..3f4903f 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -291,7 +291,7 @@ int LSPServer::Run() std::cout << message << std::flush; #endif }); - // clang-format off + // clang-format on spdlog::get(logger)->info("Listening..."); char c; From f188d124373b33509ced1863a93cb9d781d4c57f Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 23 Jul 2023 22:54:47 +0300 Subject: [PATCH 03/32] Update gitignore --- .gitignore | 2 +- builder/.gitignore | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 builder/.gitignore diff --git a/.gitignore b/.gitignore index 92cd44a..4ed1873 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ xcuserdata/ *.xccheckout *.xcscmblueprint .DS_Store +Testing ## Conan .conan* @@ -34,6 +35,5 @@ test_package/* .vscode CMakePresets.json - ## Python .pyenv \ No newline at end of file diff --git a/builder/.gitignore b/builder/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/builder/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file From 31b865b75f5ce18adde539145e0d07323fd1a59d Mon Sep 17 00:00:00 2001 From: Galarius Date: Mon, 24 Jul 2023 23:30:05 +0300 Subject: [PATCH 04/32] Define IJsonRPC interface and hide JsonRPC class impl --- include/jsonrpc.hpp | 104 ++++++++++++++------------------------------ src/jsonrpc.cpp | 86 +++++++++++++++++++++++++++++++++--- src/lsp.cpp | 36 +++++++-------- tests/main.cpp | 42 +++++++++--------- 4 files changed, 154 insertions(+), 114 deletions(-) diff --git a/include/jsonrpc.hpp b/include/jsonrpc.hpp index 5b3f029..a29afbd 100644 --- a/include/jsonrpc.hpp +++ b/include/jsonrpc.hpp @@ -8,98 +8,58 @@ #pragma once #include -#include +#include #include -#include -#include -#include namespace ocls { -class JsonRPC -{ - using InputCallbackFunc = std::function; - using OutputCallbackFunc = std::function; - -public: - // clang-format off - enum class ErrorCode : int - { - ///@{ - ParseError = -32700, ///< Parse error Invalid JSON was received by the server. An error occurred on the - ///< server while parsing the JSON text. - InvalidRequest = -32600, ///< Invalid Request The JSON sent is not a valid Request object. - MethodNotFound = -32601, ///< Method not found The method does not exist / is not available. - InvalidParams = -32602, ///< Invalid params Invalid method parameter(s). - InternalError = -32603, ///< Internal error Internal JSON-RPC error. - // -32000 to -32099 Server error Reserved for implementation-defined server-errors. - NotInitialized = -32002 ///< The first client's message is not equal to "initialize" - ///@} - }; - // clang-format on +using InputCallbackFunc = std::function; +using OutputCallbackFunc = std::function; - friend std::ostream& operator<<(std::ostream& out, ErrorCode const& code) - { - out << static_cast(code); - return out; - } +// clang-format off +enum class JRPCErrorCode : int +{ + ///@{ + ParseError = -32700, ///< Parse error Invalid JSON was received by the server. An error occurred on the + ///< server while parsing the JSON text. + InvalidRequest = -32600, ///< Invalid Request The JSON sent is not a valid Request object. + MethodNotFound = -32601, ///< Method not found The method does not exist / is not available. + InvalidParams = -32602, ///< Invalid params Invalid method parameter(s). + InternalError = -32603, ///< Internal error Internal JSON-RPC error. + // -32000 to -32099 Server error Reserved for implementation-defined server-errors. + NotInitialized = -32002 ///< The first client's message is not equal to "initialize" + ///@} +}; +// clang-format on +struct IJsonRPC +{ /** Register callback to be notified on the specific method notification. All unregistered notifications will be responded with MethodNotFound automatically. - */ - void RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func); + */ + virtual void RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func) = 0; /** Register callback to be notified on client responds to server (our) requests. */ - void RegisterInputCallback(InputCallbackFunc&& func); + virtual void RegisterInputCallback(InputCallbackFunc&& func) = 0; /** Register callback to be notified when server is going to send the final message to the client. Basically it should be redirected to the stdout. */ - void RegisterOutputCallback(OutputCallbackFunc&& func); + virtual void RegisterOutputCallback(OutputCallbackFunc&& func) = 0; - void Consume(char c); - bool IsReady() const; - void Write(const nlohmann::json& data) const; - void Reset(); + virtual void Consume(char c) = 0; + virtual bool IsReady() const = 0; + virtual void Write(const nlohmann::json& data) const = 0; + virtual void Reset() = 0; /** Send trace message to client. */ - void WriteTrace(const std::string& message, const std::string& verbose); - void WriteError(JsonRPC::ErrorCode errorCode, const std::string& message) const; - -private: - void ProcessBufferContent(); - void ProcessMethod(); - void ProcessBufferHeader(); - - void OnInitialize(); - void OnTracingChanged(const nlohmann::json& data); - bool ReadHeader(); - void FireMethodCallback(); - void FireRespondCallback(); - - void LogBufferContent() const; - void LogMessage(const std::string& message) const; - void LogAndHandleParseError(std::exception& e); - void LogAndHandleUnexpectedMessage(); - -private: - std::string m_method; - std::string m_buffer; - nlohmann::json m_body; - std::unordered_map m_headers; - std::unordered_map m_callbacks; - OutputCallbackFunc m_outputCallback; - InputCallbackFunc m_respondCallback; - bool m_isProcessing = true; - bool m_initialized = false; - bool m_validHeader = false; - bool m_tracing = false; - bool m_verbosity = false; - unsigned long m_contentLength = 0; - std::regex m_headerRegex {"([\\w-]+): (.+)\\r\\n(?:([^:]+)\\r\\n)?"}; + virtual void WriteTrace(const std::string& message, const std::string& verbose) = 0; + virtual void WriteError(JRPCErrorCode errorCode, const std::string& message) const = 0; }; +std::shared_ptr CreateJsonRPC(); + } // namespace ocls diff --git a/src/jsonrpc.cpp b/src/jsonrpc.cpp index fdcbcbb..15e53d6 100644 --- a/src/jsonrpc.cpp +++ b/src/jsonrpc.cpp @@ -6,7 +6,11 @@ // #include "jsonrpc.hpp" + +#include +#include #include +#include using namespace nlohmann; @@ -24,6 +28,73 @@ inline std::string FormatBool(bool flag) } // namespace +class JsonRPC final : public IJsonRPC +{ +public: + friend std::ostream& operator<<(std::ostream& out, JRPCErrorCode const& code) + { + out << static_cast(code); + return out; + } + + /** + Register callback to be notified on the specific method notification. + All unregistered notifications will be responded with MethodNotFound automatically. + */ + void RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func); + /** + Register callback to be notified on client responds to server (our) requests. + */ + void RegisterInputCallback(InputCallbackFunc&& func); + /** + Register callback to be notified when server is going to send the final message to the client. + Basically it should be redirected to the stdout. + */ + void RegisterOutputCallback(OutputCallbackFunc&& func); + + void Consume(char c); + bool IsReady() const; + void Write(const nlohmann::json& data) const; + void Reset(); + /** + Send trace message to client. + */ + void WriteTrace(const std::string& message, const std::string& verbose); + void WriteError(JRPCErrorCode errorCode, const std::string& message) const; + +private: + void ProcessBufferContent(); + void ProcessMethod(); + void ProcessBufferHeader(); + + void OnInitialize(); + void OnTracingChanged(const nlohmann::json& data); + bool ReadHeader(); + void FireMethodCallback(); + void FireRespondCallback(); + + void LogBufferContent() const; + void LogMessage(const std::string& message) const; + void LogAndHandleParseError(std::exception& e); + void LogAndHandleUnexpectedMessage(); + +private: + std::string m_method; + std::string m_buffer; + nlohmann::json m_body; + std::unordered_map m_headers; + std::unordered_map m_callbacks; + OutputCallbackFunc m_outputCallback; + InputCallbackFunc m_respondCallback; + bool m_isProcessing = true; + bool m_initialized = false; + bool m_validHeader = false; + bool m_tracing = false; + bool m_verbosity = false; + unsigned long m_contentLength = 0; + std::regex m_headerRegex {"([\\w-]+): (.+)\\r\\n(?:([^:]+)\\r\\n)?"}; +}; + void JsonRPC::RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func) { spdlog::get(logger)->trace("Set callback for method: {}", method); @@ -189,7 +260,7 @@ void JsonRPC::ProcessBufferHeader() } else { - WriteError(ErrorCode::InvalidRequest, "Invalid content length"); + WriteError(JRPCErrorCode::InvalidRequest, "Invalid content length"); } } } @@ -268,7 +339,7 @@ void JsonRPC::FireMethodCallback() "Got request: {}, respond is required: {}", FormatBool(isRequest), FormatBool(mustRespond)); if (mustRespond) { - WriteError(ErrorCode::MethodNotFound, "Method '" + m_method + "' is not supported."); + WriteError(JRPCErrorCode::MethodNotFound, "Method '" + m_method + "' is not supported."); } } else @@ -285,7 +356,7 @@ void JsonRPC::FireMethodCallback() } } -void JsonRPC::WriteError(JsonRPC::ErrorCode errorCode, const std::string& message) const +void JsonRPC::WriteError(JRPCErrorCode errorCode, const std::string& message) const { spdlog::get(logger)->trace("Reporting error: '{}' ({})", message, static_cast(errorCode)); json obj = { @@ -333,13 +404,18 @@ void JsonRPC::LogAndHandleParseError(std::exception& e) { spdlog::get(logger)->error("Failed to parse request with reason: '{}'\n{}", e.what(), "\n", m_buffer); m_buffer.clear(); - WriteError(ErrorCode::ParseError, "Failed to parse request"); + WriteError(JRPCErrorCode::ParseError, "Failed to parse request"); } void JsonRPC::LogAndHandleUnexpectedMessage() { spdlog::get(logger)->error("Unexpected first message: '{}'", m_method); - WriteError(ErrorCode::NotInitialized, "Server was not initialized."); + WriteError(JRPCErrorCode::NotInitialized, "Server was not initialized."); +} + +std::shared_ptr CreateJsonRPC() +{ + return std::shared_ptr(new JsonRPC()); } } // namespace ocls diff --git a/src/lsp.cpp b/src/lsp.cpp index 3f4903f..1cf6b65 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -11,8 +11,8 @@ #include "utils.hpp" #include +#include #include - #include using namespace nlohmann; @@ -32,7 +32,9 @@ class LSPServer final , public std::enable_shared_from_this { public: - LSPServer() : m_diagnostics(CreateDiagnostics(CreateCLInfo())) {} + LSPServer() + : m_jrpc(CreateJsonRPC()) + , m_diagnostics(CreateDiagnostics(CreateCLInfo())) {} int Run(); void Interrupt(); @@ -50,7 +52,7 @@ class LSPServer final void OnExit(); private: - JsonRPC m_jrpc; + std::shared_ptr m_jrpc; std::shared_ptr m_diagnostics; std::queue m_outQueue; Capabilities m_capabilities; @@ -159,7 +161,7 @@ void LSPServer::BuildDiagnosticsRespond(const std::string &uri, const std::strin { auto msg = std::string("Failed to get diagnostics: ") + err.what(); spdlog::get(logger)->error(msg); - m_jrpc.WriteError(JsonRPC::ErrorCode::InternalError, msg); + m_jrpc->WriteError(JRPCErrorCode::InternalError, msg); } } @@ -248,41 +250,41 @@ int LSPServer::Run() auto self = this->shared_from_this(); // clang-format off // Register handlers for methods - m_jrpc.RegisterMethodCallback("initialize", [self](const json &request) + m_jrpc->RegisterMethodCallback("initialize", [self](const json &request) { self->OnInitialize(request); }); - m_jrpc.RegisterMethodCallback("initialized", [self](const json &request) + m_jrpc->RegisterMethodCallback("initialized", [self](const json &request) { self->OnInitialized(request); }); - m_jrpc.RegisterMethodCallback("shutdown", [self](const json &request) + m_jrpc->RegisterMethodCallback("shutdown", [self](const json &request) { self->OnShutdown(request); }); - m_jrpc.RegisterMethodCallback("exit", [self](const json &) + m_jrpc->RegisterMethodCallback("exit", [self](const json &) { self->OnExit(); }); - m_jrpc.RegisterMethodCallback("textDocument/didOpen", [self](const json &request) + m_jrpc->RegisterMethodCallback("textDocument/didOpen", [self](const json &request) { self->OnTextOpen(request); }); - m_jrpc.RegisterMethodCallback("textDocument/didChange", [self](const json &request) + m_jrpc->RegisterMethodCallback("textDocument/didChange", [self](const json &request) { self->OnTextChanged(request); }); - m_jrpc.RegisterMethodCallback("workspace/didChangeConfiguration", [self](const json &) + m_jrpc->RegisterMethodCallback("workspace/didChangeConfiguration", [self](const json &) { self->GetConfiguration(); }); // Register handler for client responds - m_jrpc.RegisterInputCallback([self](const json &respond) + m_jrpc->RegisterInputCallback([self](const json &respond) { self->OnRespond(respond); }); // Register handler for message delivery - m_jrpc.RegisterOutputCallback([](const std::string &message) + m_jrpc->RegisterOutputCallback([](const std::string &message) { #if defined(WIN32) printf_s("%s", message.c_str()); @@ -300,14 +302,14 @@ int LSPServer::Run() if(m_interrupted.load()) { return EINTR; } - m_jrpc.Consume(c); - if (m_jrpc.IsReady()) + m_jrpc->Consume(c); + if (m_jrpc->IsReady()) { - m_jrpc.Reset(); + m_jrpc->Reset(); while (!m_outQueue.empty()) { auto data = m_outQueue.front(); - m_jrpc.Write(data); + m_jrpc->Write(data); m_outQueue.pop(); } } diff --git a/tests/main.cpp b/tests/main.cpp index ac7b42b..b55f23b 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -33,10 +33,12 @@ std::string BuildRequest(const json& obj) return BuildRequest(obj.dump()); } -void Send(const std::string& request, JsonRPC& jrpc) +void Send(const std::string& request, std::shared_ptr jrpc) { for (auto c : request) - jrpc.Consume(c); + { + jrpc->Consume(c); + } } std::string ParseResponse(std::string str) @@ -53,51 +55,51 @@ std::string ParseResponse(std::string str) const std::string initRequest = BuildRequest(json::object( {{"jsonrpc", "2.0"}, {"id", 0}, {"method", "initialize"}, {"params", {{"processId", 60650}, {"trace", "off"}}}})); -const auto InitializeJsonRPC = [](JsonRPC& jrpc) { - jrpc.RegisterOutputCallback([](const std::string&) {}); - jrpc.RegisterMethodCallback("initialize", [](const json&) {}); +const auto InitializeJsonRPC = [](std::shared_ptr jrpc) { + jrpc->RegisterOutputCallback([](const std::string&) {}); + jrpc->RegisterMethodCallback("initialize", [](const json&) {}); Send(initRequest, jrpc); - jrpc.Reset(); + jrpc->Reset(); }; } // namespace TEST(JsonRPCTest, InvalidRequestHandling) { - JsonRPC jrpc; + auto jrpc = CreateJsonRPC(); json response; const std::string request = R"!({"jsonrpc: 2.0", "id":0, [method]: "initialize"})!"; const std::string message = BuildRequest(request); - jrpc.RegisterOutputCallback( + jrpc->RegisterOutputCallback( [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); Send(message, jrpc); const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::ParseError)); + EXPECT_EQ(code, static_cast(JRPCErrorCode::ParseError)); } TEST(JsonRPCTest, OutOfOrderRequest) { - JsonRPC jrpc; + auto jrpc = CreateJsonRPC(); json response; const std::string message = BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc.RegisterOutputCallback( + jrpc->RegisterOutputCallback( [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); Send(message, jrpc); const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::NotInitialized)); + EXPECT_EQ(code, static_cast(JRPCErrorCode::NotInitialized)); } TEST(JsonRPCTest, MethodInitialize) { - JsonRPC jrpc; + auto jrpc = CreateJsonRPC(); int64_t processId = 0; - jrpc.RegisterOutputCallback([](const std::string&) {}); - jrpc.RegisterMethodCallback( + jrpc->RegisterOutputCallback([](const std::string&) {}); + jrpc->RegisterMethodCallback( "initialize", [&processId](const json& request) { processId = request["params"]["processId"].get(); }); Send(initRequest, jrpc); @@ -107,29 +109,29 @@ TEST(JsonRPCTest, MethodInitialize) TEST(JsonRPCTest, RespondToUnsupportedMethod) { - JsonRPC jrpc; + auto jrpc = CreateJsonRPC(); InitializeJsonRPC(jrpc); json response; const std::string request = BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc.RegisterOutputCallback( + jrpc->RegisterOutputCallback( [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); // send unsupported request Send(request, jrpc); const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JsonRPC::ErrorCode::MethodNotFound)); + EXPECT_EQ(code, static_cast(JRPCErrorCode::MethodNotFound)); } TEST(JsonRPCTest, RespondToSupportedMethod) { - JsonRPC jrpc; + auto jrpc = CreateJsonRPC(); InitializeJsonRPC(jrpc); bool isCallbackCalled = false; const std::string request = BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc.RegisterMethodCallback( + jrpc->RegisterMethodCallback( "textDocument/didOpen", [&isCallbackCalled]([[maybe_unused]] const json& request) { isCallbackCalled = true; }); Send(request, jrpc); From b9074bdd45b7c6bb6442a01312d752510f21cb3e Mon Sep 17 00:00:00 2001 From: Galarius Date: Mon, 24 Jul 2023 23:40:37 +0300 Subject: [PATCH 05/32] Make `diagnostics` and `jrpc` instances injectable into `lsp` --- include/lsp.hpp | 5 ++++- src/lsp.cpp | 10 +++++----- src/main.cpp | 8 ++++++-- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/include/lsp.hpp b/include/lsp.hpp index 4bfdb0c..3c6de09 100644 --- a/include/lsp.hpp +++ b/include/lsp.hpp @@ -9,6 +9,9 @@ #include +#include "diagnostics.hpp" +#include "jsonrpc.hpp" + namespace ocls { struct ILSPServer @@ -17,6 +20,6 @@ struct ILSPServer virtual void Interrupt() = 0; }; -std::shared_ptr CreateLSPServer(); +std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics); } // namespace ocls diff --git a/src/lsp.cpp b/src/lsp.cpp index 1cf6b65..1b60397 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -32,9 +32,9 @@ class LSPServer final , public std::enable_shared_from_this { public: - LSPServer() - : m_jrpc(CreateJsonRPC()) - , m_diagnostics(CreateDiagnostics(CreateCLInfo())) {} + LSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) + : m_jrpc { std::move(jrpc) } + , m_diagnostics { std::move(diagnostics) } {} int Run(); void Interrupt(); @@ -322,9 +322,9 @@ void LSPServer::Interrupt() m_interrupted.store(true); } -std::shared_ptr CreateLSPServer() +std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) { - return std::shared_ptr(new LSPServer()); + return std::shared_ptr(new LSPServer(std::move(jrpc), std::move(diagnostics))); } } // namespace ocls diff --git a/src/main.cpp b/src/main.cpp index 2c98314..5d18724 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -8,6 +8,8 @@ #include #include "clinfo.hpp" +#include "diagnostics.hpp" +#include "jsonrpc.hpp" #include "lsp.hpp" #include "version.hpp" @@ -119,9 +121,9 @@ int main(int argc, char* argv[]) ConfigureLogging(flagLogTofile, optLogFile, optLogLevel); + auto clinfo = CreateCLInfo(); if (flagCLInfo) { - const auto clinfo = CreateCLInfo(); const auto jsonBody = clinfo->json(); std::cout << jsonBody.dump() << std::endl; exit(0); @@ -131,6 +133,8 @@ int main(int argc, char* argv[]) std::signal(SIGINT, SignalHandler); - server = CreateLSPServer(); + auto jrpc = CreateJsonRPC(); + auto diagnostics = CreateDiagnostics(clinfo); + server = CreateLSPServer(jrpc, diagnostics); return server->Run(); } From a2c35e7c5ba05432d7447f7f03f12237af5af7a5 Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 25 Jul 2023 21:24:15 +0300 Subject: [PATCH 06/32] Move json-rpc tests to a separate file --- tests/CMakeLists.txt | 11 +++- tests/jsonrpc-tests.cpp | 139 ++++++++++++++++++++++++++++++++++++++++ tests/main.cpp | 133 +------------------------------------- 3 files changed, 149 insertions(+), 134 deletions(-) create mode 100644 tests/jsonrpc-tests.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 10c3c9c..8a1fa6a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,9 +1,14 @@ set(TESTS_PROJECT_NAME ${PROJECT_NAME}-tests) set(headers - "${PROJECT_SOURCE_DIR}/include/jsonrpc.hpp" + jsonrpc.hpp ) set(sources - "${PROJECT_SOURCE_DIR}/src/jsonrpc.cpp" + jsonrpc.cpp +) +list(TRANSFORM headers PREPEND "${PROJECT_SOURCE_DIR}/include/") +list(TRANSFORM sources PREPEND "${PROJECT_SOURCE_DIR}/src/") +set(test_sources + jsonrpc-tests.cpp main.cpp ) set(libs GTest::gtest nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp) @@ -15,7 +20,7 @@ elseif(WIN32) set(libs ${libs} OpenCL::OpenCL) endif() -add_executable (${TESTS_PROJECT_NAME} ${headers} ${sources}) +add_executable (${TESTS_PROJECT_NAME} ${headers} ${sources} ${test_sources}) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${PROJECT_SOURCE_DIR}/include" ) diff --git a/tests/jsonrpc-tests.cpp b/tests/jsonrpc-tests.cpp new file mode 100644 index 0000000..d26d78a --- /dev/null +++ b/tests/jsonrpc-tests.cpp @@ -0,0 +1,139 @@ +// +// jsonrpc-tests.cpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 7/16/21. +// + +#include "diagnostics.hpp" +#include "jsonrpc.hpp" + +#include +#include + + +using namespace ocls; +using namespace nlohmann; + +namespace { + +std::string BuildRequest(const std::string& content) +{ + std::string request; + request.append("Content-Length: " + std::to_string(content.size()) + "\r\n"); + request.append("Content-Type: application/vscode-jsonrpc;charset=utf-8\r\n"); + request.append("\r\n"); + request.append(content); + return request; +} + +std::string BuildRequest(const json& obj) +{ + return BuildRequest(obj.dump()); +} + +void Send(const std::string& request, std::shared_ptr jrpc) +{ + for (auto c : request) + { + jrpc->Consume(c); + } +} + +std::string ParseResponse(std::string str) +{ + std::string delimiter = "\r\n"; + size_t pos = 0; + while ((pos = str.find(delimiter)) != std::string::npos) + { + str.erase(0, pos + delimiter.length()); + } + return str; +} + +const std::string initRequest = BuildRequest(json::object( + {{"jsonrpc", "2.0"}, {"id", 0}, {"method", "initialize"}, {"params", {{"processId", 60650}, {"trace", "off"}}}})); + +const auto InitializeJsonRPC = [](std::shared_ptr jrpc) { + jrpc->RegisterOutputCallback([](const std::string&) {}); + jrpc->RegisterMethodCallback("initialize", [](const json&) {}); + Send(initRequest, jrpc); + jrpc->Reset(); +}; + +} // namespace + +TEST(JsonRPCTest, InvalidRequestHandling) +{ + auto jrpc = CreateJsonRPC(); + json response; + const std::string request = R"!({"jsonrpc: 2.0", "id":0, [method]: "initialize"})!"; + const std::string message = BuildRequest(request); + jrpc->RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + + Send(message, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JRPCErrorCode::ParseError)); +} + +TEST(JsonRPCTest, OutOfOrderRequest) +{ + auto jrpc = CreateJsonRPC(); + json response; + const std::string message = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc->RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + + Send(message, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JRPCErrorCode::NotInitialized)); +} + +TEST(JsonRPCTest, MethodInitialize) +{ + auto jrpc = CreateJsonRPC(); + int64_t processId = 0; + jrpc->RegisterOutputCallback([](const std::string&) {}); + jrpc->RegisterMethodCallback( + "initialize", [&processId](const json& request) { processId = request["params"]["processId"].get(); }); + + Send(initRequest, jrpc); + + EXPECT_EQ(processId, 60650); +} + +TEST(JsonRPCTest, RespondToUnsupportedMethod) +{ + auto jrpc = CreateJsonRPC(); + InitializeJsonRPC(jrpc); + json response; + const std::string request = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc->RegisterOutputCallback( + [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); + + // send unsupported request + Send(request, jrpc); + + const auto code = response["error"]["code"].get(); + EXPECT_EQ(code, static_cast(JRPCErrorCode::MethodNotFound)); +} + +TEST(JsonRPCTest, RespondToSupportedMethod) +{ + auto jrpc = CreateJsonRPC(); + InitializeJsonRPC(jrpc); + bool isCallbackCalled = false; + const std::string request = + BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); + jrpc->RegisterMethodCallback( + "textDocument/didOpen", [&isCallbackCalled]([[maybe_unused]] const json& request) { isCallbackCalled = true; }); + + Send(request, jrpc); + + EXPECT_TRUE(isCallbackCalled); +} \ No newline at end of file diff --git a/tests/main.cpp b/tests/main.cpp index b55f23b..1c00956 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -2,143 +2,14 @@ // main.cpp // opencl-language-server-tests // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #include - -#include "diagnostics.hpp" -#include "jsonrpc.hpp" -#include +#include #include #include -using namespace ocls; -using namespace nlohmann; - -namespace { - -std::string BuildRequest(const std::string& content) -{ - std::string request; - request.append("Content-Length: " + std::to_string(content.size()) + "\r\n"); - request.append("Content-Type: application/vscode-jsonrpc;charset=utf-8\r\n"); - request.append("\r\n"); - request.append(content); - return request; -} - -std::string BuildRequest(const json& obj) -{ - return BuildRequest(obj.dump()); -} - -void Send(const std::string& request, std::shared_ptr jrpc) -{ - for (auto c : request) - { - jrpc->Consume(c); - } -} - -std::string ParseResponse(std::string str) -{ - std::string delimiter = "\r\n"; - size_t pos = 0; - while ((pos = str.find(delimiter)) != std::string::npos) - { - str.erase(0, pos + delimiter.length()); - } - return str; -} - -const std::string initRequest = BuildRequest(json::object( - {{"jsonrpc", "2.0"}, {"id", 0}, {"method", "initialize"}, {"params", {{"processId", 60650}, {"trace", "off"}}}})); - -const auto InitializeJsonRPC = [](std::shared_ptr jrpc) { - jrpc->RegisterOutputCallback([](const std::string&) {}); - jrpc->RegisterMethodCallback("initialize", [](const json&) {}); - Send(initRequest, jrpc); - jrpc->Reset(); -}; - -} // namespace - -TEST(JsonRPCTest, InvalidRequestHandling) -{ - auto jrpc = CreateJsonRPC(); - json response; - const std::string request = R"!({"jsonrpc: 2.0", "id":0, [method]: "initialize"})!"; - const std::string message = BuildRequest(request); - jrpc->RegisterOutputCallback( - [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); - - Send(message, jrpc); - - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JRPCErrorCode::ParseError)); -} - -TEST(JsonRPCTest, OutOfOrderRequest) -{ - auto jrpc = CreateJsonRPC(); - json response; - const std::string message = - BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc->RegisterOutputCallback( - [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); - - Send(message, jrpc); - - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JRPCErrorCode::NotInitialized)); -} - -TEST(JsonRPCTest, MethodInitialize) -{ - auto jrpc = CreateJsonRPC(); - int64_t processId = 0; - jrpc->RegisterOutputCallback([](const std::string&) {}); - jrpc->RegisterMethodCallback( - "initialize", [&processId](const json& request) { processId = request["params"]["processId"].get(); }); - - Send(initRequest, jrpc); - - EXPECT_EQ(processId, 60650); -} - -TEST(JsonRPCTest, RespondToUnsupportedMethod) -{ - auto jrpc = CreateJsonRPC(); - InitializeJsonRPC(jrpc); - json response; - const std::string request = - BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc->RegisterOutputCallback( - [&response](const std::string& message) { response = json::parse(ParseResponse(message)); }); - - // send unsupported request - Send(request, jrpc); - - const auto code = response["error"]["code"].get(); - EXPECT_EQ(code, static_cast(JRPCErrorCode::MethodNotFound)); -} - -TEST(JsonRPCTest, RespondToSupportedMethod) -{ - auto jrpc = CreateJsonRPC(); - InitializeJsonRPC(jrpc); - bool isCallbackCalled = false; - const std::string request = - BuildRequest(json::object({{"jsonrpc", "2.0"}, {"id", 0}, {"method", "textDocument/didOpen"}, {"params", {}}})); - jrpc->RegisterMethodCallback( - "textDocument/didOpen", [&isCallbackCalled]([[maybe_unused]] const json& request) { isCallbackCalled = true; }); - - Send(request, jrpc); - - EXPECT_TRUE(isCallbackCalled); -} - int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); From aab26e890b08587ed4b91e3b081090460d334b28 Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 25 Jul 2023 23:26:33 +0300 Subject: [PATCH 07/32] Refactor LSPServer for better testability --- include/diagnostics.hpp | 2 + include/jsonrpc.hpp | 2 + include/lsp.hpp | 31 +++++++- src/lsp.cpp | 126 ++++++++++++++++++++++-------- tests/CMakeLists.txt | 12 +-- tests/lsp-event-handler-tests.cpp | 66 ++++++++++++++++ tests/mocks/diagnostics-mock.hpp | 22 ++++++ tests/mocks/jsonrpc-mock.hpp | 32 ++++++++ 8 files changed, 252 insertions(+), 41 deletions(-) create mode 100644 tests/lsp-event-handler-tests.cpp create mode 100644 tests/mocks/diagnostics-mock.hpp create mode 100644 tests/mocks/jsonrpc-mock.hpp diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index 01b24b0..399f2cd 100755 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -23,6 +23,8 @@ struct Source struct IDiagnostics { + virtual ~IDiagnostics() = default; + virtual void SetBuildOptions(const nlohmann::json& options) = 0; virtual void SetMaxProblemsCount(int maxNumberOfProblems) = 0; virtual void SetOpenCLDevice(uint32_t identifier) = 0; diff --git a/include/jsonrpc.hpp b/include/jsonrpc.hpp index a29afbd..fbe1fb8 100644 --- a/include/jsonrpc.hpp +++ b/include/jsonrpc.hpp @@ -34,6 +34,8 @@ enum class JRPCErrorCode : int struct IJsonRPC { + virtual ~IJsonRPC() = default; + /** Register callback to be notified on the specific method notification. All unregistered notifications will be responded with MethodNotFound automatically. diff --git a/include/lsp.hpp b/include/lsp.hpp index 3c6de09..9575056 100644 --- a/include/lsp.hpp +++ b/include/lsp.hpp @@ -2,12 +2,13 @@ // lsp.hpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #pragma once #include +#include #include "diagnostics.hpp" #include "jsonrpc.hpp" @@ -16,10 +17,36 @@ namespace ocls { struct ILSPServer { + virtual ~ILSPServer() = default; + virtual int Run() = 0; virtual void Interrupt() = 0; }; -std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics); +struct ILSPServerEventsHandler +{ + virtual ~ILSPServerEventsHandler() = default; + + virtual void BuildDiagnosticsRespond(const std::string &uri, const std::string &content) = 0; + virtual void GetConfiguration() = 0; + virtual std::optional GetNextResponse() = 0; + virtual void OnInitialize(const nlohmann::json &data) = 0; + virtual void OnInitialized(const nlohmann::json &data) = 0; + virtual void OnTextOpen(const nlohmann::json &data) = 0; + virtual void OnTextChanged(const nlohmann::json &data) = 0; + virtual void OnConfiguration(const nlohmann::json &data) = 0; + virtual void OnRespond(const nlohmann::json &data) = 0; + virtual void OnShutdown(const nlohmann::json &data) = 0; + virtual void OnExit() = 0; +}; + +std::shared_ptr CreateLSPEventsHandler( + std::shared_ptr jrpc, std::shared_ptr diagnostics); + +std::shared_ptr CreateLSPServer( + std::shared_ptr jrpc, std::shared_ptr handler); + +std::shared_ptr CreateLSPServer( + std::shared_ptr jrpc, std::shared_ptr diagnostics); } // namespace ocls diff --git a/src/lsp.cpp b/src/lsp.cpp index 1b60397..3513677 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -2,7 +2,7 @@ // lsp.cpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #include "lsp.hpp" @@ -32,9 +32,10 @@ class LSPServer final , public std::enable_shared_from_this { public: - LSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) - : m_jrpc { std::move(jrpc) } - , m_diagnostics { std::move(diagnostics) } {} + LSPServer(std::shared_ptr jrpc, std::shared_ptr handler) + : m_jrpc {std::move(jrpc)} + , m_handler {std::move(handler)} + {} int Run(); void Interrupt(); @@ -51,6 +52,32 @@ class LSPServer final void OnShutdown(const json &data); void OnExit(); +private: + std::shared_ptr m_jrpc; + std::shared_ptr m_handler; + std::atomic m_interrupted = {false}; +}; + +class LSPServerEventsHandler final : public ILSPServerEventsHandler +{ +public: + LSPServerEventsHandler(std::shared_ptr jrpc, std::shared_ptr diagnostics) + : m_jrpc {std::move(jrpc)} + , m_diagnostics {std::move(diagnostics)} + {} + + void BuildDiagnosticsRespond(const std::string &uri, const std::string &content); + void GetConfiguration(); + std::optional GetNextResponse(); + void OnInitialize(const json &data); + void OnInitialized(const json &data); + void OnTextOpen(const json &data); + void OnTextChanged(const json &data); + void OnConfiguration(const json &data); + void OnRespond(const json &data); + void OnShutdown(const json &data); + void OnExit(); + private: std::shared_ptr m_jrpc; std::shared_ptr m_diagnostics; @@ -58,16 +85,18 @@ class LSPServer final Capabilities m_capabilities; std::queue> m_requests; bool m_shutdown = false; - std::atomic m_interrupted = {false}; }; -void LSPServer::GetConfiguration() +// ILSPServerEventsHandler + +void LSPServerEventsHandler::GetConfiguration() { if (!m_capabilities.hasConfigurationCapability) { spdlog::get(logger)->debug("Does not have configuration capability"); return; } + spdlog::get(logger)->debug("Make configuration request"); json buildOptions = {{"section", "OpenCL.server.buildOptions"}}; json maxNumberOfProblems = {{"section", "OpenCL.server.maxNumberOfProblems"}}; @@ -80,8 +109,19 @@ void LSPServer::GetConfiguration() {"params", {{"items", json::array({buildOptions, maxNumberOfProblems, openCLDeviceID})}}}}); } +std::optional LSPServerEventsHandler::GetNextResponse() +{ + if (m_outQueue.empty()) + { + return std::nullopt; + } + + auto data = m_outQueue.front(); + m_outQueue.pop(); + return data; +} -void LSPServer::OnInitialize(const json &data) +void LSPServerEventsHandler::OnInitialize(const json &data) { spdlog::get(logger)->debug("Received 'initialize' request"); try @@ -101,7 +141,7 @@ void LSPServer::OnInitialize(const json &data) auto deviceID = configuration["deviceID"].get(); m_diagnostics->SetOpenCLDevice(static_cast(deviceID)); } - catch (std::exception &err) + catch (nlohmann::json::out_of_range &err) { spdlog::get(logger)->error("Failed to parse initialize parameters, {}", err.what()); } @@ -120,7 +160,7 @@ void LSPServer::OnInitialize(const json &data) m_outQueue.push({{"id", data["id"]}, {"result", {{"capabilities", capabilities}}}}); } -void LSPServer::OnInitialized(const json &) +void LSPServerEventsHandler::OnInitialized(const json &) { spdlog::get(logger)->debug("Received 'initialized' message"); if (!m_capabilities.supportDidChangeConfiguration) @@ -141,7 +181,7 @@ void LSPServer::OnInitialized(const json &) m_outQueue.push({{"id", utils::GenerateId()}, {"method", "client/registerCapability"}, {"params", params}}); } -void LSPServer::BuildDiagnosticsRespond(const std::string &uri, const std::string &content) +void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, const std::string &content) { try { @@ -165,7 +205,7 @@ void LSPServer::BuildDiagnosticsRespond(const std::string &uri, const std::strin } } -void LSPServer::OnTextOpen(const json &data) +void LSPServerEventsHandler::OnTextOpen(const json &data) { spdlog::get(logger)->debug("Received 'textOpen' message"); std::string srcUri = data["params"]["textDocument"]["uri"].get(); @@ -173,7 +213,7 @@ void LSPServer::OnTextOpen(const json &data) BuildDiagnosticsRespond(srcUri, content); } -void LSPServer::OnTextChanged(const json &data) +void LSPServerEventsHandler::OnTextChanged(const json &data) { spdlog::get(logger)->debug("Received 'textChanged' message"); std::string srcUri = data["params"]["textDocument"]["uri"].get(); @@ -182,7 +222,7 @@ void LSPServer::OnTextChanged(const json &data) BuildDiagnosticsRespond(srcUri, content); } -void LSPServer::OnConfiguration(const json &data) +void LSPServerEventsHandler::OnConfiguration(const json &data) { spdlog::get(logger)->debug("Received 'configuration' respond"); auto result = data["result"]; @@ -215,7 +255,7 @@ void LSPServer::OnConfiguration(const json &data) } } -void LSPServer::OnRespond(const json &data) +void LSPServerEventsHandler::OnRespond(const json &data) { spdlog::get(logger)->debug("Received client respond"); const auto id = data["id"]; @@ -228,14 +268,14 @@ void LSPServer::OnRespond(const json &data) } } -void LSPServer::OnShutdown(const json &data) +void LSPServerEventsHandler::OnShutdown(const json &data) { spdlog::get(logger)->debug("Received 'shutdown' request"); m_outQueue.push({{"id", data["id"]}, {"result", nullptr}}); m_shutdown = true; } -void LSPServer::OnExit() +void LSPServerEventsHandler::OnExit() { spdlog::get(logger)->debug("Received 'exit', after 'shutdown': {}", m_shutdown ? "yes" : "no"); if (m_shutdown) @@ -244,6 +284,8 @@ void LSPServer::OnExit() exit(EXIT_FAILURE); } +// ILSPServer + int LSPServer::Run() { spdlog::get(logger)->info("Setting up..."); @@ -252,36 +294,36 @@ int LSPServer::Run() // Register handlers for methods m_jrpc->RegisterMethodCallback("initialize", [self](const json &request) { - self->OnInitialize(request); + self->m_handler->OnInitialize(request); }); m_jrpc->RegisterMethodCallback("initialized", [self](const json &request) { - self->OnInitialized(request); + self->m_handler->OnInitialized(request); }); m_jrpc->RegisterMethodCallback("shutdown", [self](const json &request) { - self->OnShutdown(request); + self->m_handler->OnShutdown(request); }); m_jrpc->RegisterMethodCallback("exit", [self](const json &) { - self->OnExit(); + self->m_handler->OnExit(); }); m_jrpc->RegisterMethodCallback("textDocument/didOpen", [self](const json &request) { - self->OnTextOpen(request); + self->m_handler->OnTextOpen(request); }); m_jrpc->RegisterMethodCallback("textDocument/didChange", [self](const json &request) { - self->OnTextChanged(request); + self->m_handler->OnTextChanged(request); }); m_jrpc->RegisterMethodCallback("workspace/didChangeConfiguration", [self](const json &) { - self->GetConfiguration(); + self->m_handler->GetConfiguration(); }); // Register handler for client responds m_jrpc->RegisterInputCallback([self](const json &respond) { - self->OnRespond(respond); + self->m_handler->OnRespond(respond); }); // Register handler for message delivery m_jrpc->RegisterOutputCallback([](const std::string &message) @@ -294,37 +336,55 @@ int LSPServer::Run() #endif }); // clang-format on - + spdlog::get(logger)->info("Listening..."); char c; while (std::cin.get(c)) { - if(m_interrupted.load()) { + if (m_interrupted.load()) + { return EINTR; } m_jrpc->Consume(c); if (m_jrpc->IsReady()) { m_jrpc->Reset(); - while (!m_outQueue.empty()) + while (true) { - auto data = m_outQueue.front(); - m_jrpc->Write(data); - m_outQueue.pop(); + auto data = m_handler->GetNextResponse(); + if (!data.has_value()) + { + break; + } + m_jrpc->Write(*data); } } } return 0; } -void LSPServer::Interrupt() +void LSPServer::Interrupt() { m_interrupted.store(true); } -std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) +std::shared_ptr CreateLSPEventsHandler( + std::shared_ptr jrpc, std::shared_ptr diagnostics) +{ + return std::make_shared(jrpc, diagnostics); +} + +std::shared_ptr CreateLSPServer( + std::shared_ptr jrpc, std::shared_ptr handler) +{ + return std::make_shared(std::move(jrpc), std::move(handler)); +} + +std::shared_ptr CreateLSPServer( + std::shared_ptr jrpc, std::shared_ptr diagnostics) { - return std::shared_ptr(new LSPServer(std::move(jrpc), std::move(diagnostics))); + auto handler = std::make_shared(jrpc, diagnostics); + return std::make_shared(std::move(jrpc), std::move(handler)); } } // namespace ocls diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 8a1fa6a..74f8cba 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,17 +1,16 @@ set(TESTS_PROJECT_NAME ${PROJECT_NAME}-tests) -set(headers - jsonrpc.hpp -) set(sources jsonrpc.cpp + lsp.cpp + utils.cpp ) -list(TRANSFORM headers PREPEND "${PROJECT_SOURCE_DIR}/include/") list(TRANSFORM sources PREPEND "${PROJECT_SOURCE_DIR}/src/") set(test_sources jsonrpc-tests.cpp + lsp-event-handler-tests.cpp main.cpp ) -set(libs GTest::gtest nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp) +set(libs GTest::gmock nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp) if(LINUX) set(libs ${libs} stdc++fs OpenCL::OpenCL) elseif(APPLE) @@ -20,9 +19,10 @@ elseif(WIN32) set(libs ${libs} OpenCL::OpenCL) endif() -add_executable (${TESTS_PROJECT_NAME} ${headers} ${sources} ${test_sources}) +add_executable (${TESTS_PROJECT_NAME} ${sources} ${test_sources}) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${PROJECT_SOURCE_DIR}/include" + "${CMAKE_CURRENT_SOURCE_DIR}/mocks" ) if(APPLE) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${OpenCL_INCLUDE_DIRS}") diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp new file mode 100644 index 0000000..485a4eb --- /dev/null +++ b/tests/lsp-event-handler-tests.cpp @@ -0,0 +1,66 @@ +// +// lsp-event-handler-tests.cpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 7/25/23. +// + +#include "lsp.hpp" +#include "jsonrpc-mock.hpp" +#include "diagnostics-mock.hpp" + +#include +#include + +using namespace ocls; +using namespace nlohmann; + +namespace { + +class LSPTest : public ::testing::Test +{ +protected: + std::shared_ptr mockJsonRPC; + std::shared_ptr mockDiagnostics; + std::shared_ptr handler; + + void SetUp() override + { + mockJsonRPC = std::make_shared(); + mockDiagnostics = std::make_shared(); + handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics); + } +}; + +} // namespace + + +TEST_F(LSPTest, OnInitialize) +{ + nlohmann::json testData = R"({ + "params": { + "capabilities": { + "workspace": { + "configuration": true, + "didChangeConfiguration": {"dynamicRegistration": true} + } + }, + "initializationOptions": { + "configuration": { + "buildOptions": { "option": "value" }, + "maxNumberOfProblems": 10, + "deviceID": 1 + } + } + }, + "id": 1 + })"_json; + + EXPECT_CALL( + *mockDiagnostics, SetBuildOptions(testData["params"]["initializationOptions"]["configuration"]["buildOptions"])) + .Times(1); + EXPECT_CALL(*mockDiagnostics, SetMaxProblemsCount(10)).Times(1); + EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); + + handler->OnInitialize(testData); +} diff --git a/tests/mocks/diagnostics-mock.hpp b/tests/mocks/diagnostics-mock.hpp new file mode 100644 index 0000000..8700d5d --- /dev/null +++ b/tests/mocks/diagnostics-mock.hpp @@ -0,0 +1,22 @@ +// +// diagnostics-mock.hpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 7/25/23. +// + +#include "diagnostics.hpp" + +#include + +class DiagnosticsMock : public ocls::IDiagnostics +{ +public: + MOCK_METHOD(void, SetBuildOptions, (const nlohmann::json&), (override)); + + MOCK_METHOD(void, SetMaxProblemsCount, (int), (override)); + + MOCK_METHOD(void, SetOpenCLDevice, (uint32_t), (override)); + + MOCK_METHOD(nlohmann::json, Get, (const ocls::Source&), (override)); +}; diff --git a/tests/mocks/jsonrpc-mock.hpp b/tests/mocks/jsonrpc-mock.hpp new file mode 100644 index 0000000..cfdcb38 --- /dev/null +++ b/tests/mocks/jsonrpc-mock.hpp @@ -0,0 +1,32 @@ +// +// jsonrpc-mock.hpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 7/25/23. +// + +#include "jsonrpc.hpp" + +#include + +class JsonRPCMock final : public ocls::IJsonRPC +{ +public: + MOCK_METHOD(void, RegisterMethodCallback, (const std::string&, ocls::InputCallbackFunc&&), (override)); + + MOCK_METHOD(void, RegisterInputCallback, (ocls::InputCallbackFunc &&), (override)); + + MOCK_METHOD(void, RegisterOutputCallback, (ocls::OutputCallbackFunc &&), (override)); + + MOCK_METHOD(void, Consume, (char), (override)); + + MOCK_METHOD(bool, IsReady, (), (const, override)); + + MOCK_METHOD(void, Write, (const nlohmann::json&), (const, override)); + + MOCK_METHOD(void, Reset, (), (override)); + + MOCK_METHOD(void, WriteTrace, (const std::string&, const std::string&), (override)); + + MOCK_METHOD(void, WriteError, (ocls::JRPCErrorCode, const std::string&), (const, override)); +}; \ No newline at end of file From c1af8535f1f5c90d20a7da5878f76a5c86d62cfd Mon Sep 17 00:00:00 2001 From: Galarius Date: Sat, 29 Jul 2023 22:47:38 +0300 Subject: [PATCH 08/32] Refactor LSPServerEventsHandler::OnInitialize and add tests --- README.md | 2 +- include/diagnostics.hpp | 2 +- src/diagnostics.cpp | 8 ++-- src/lsp.cpp | 65 ++++++++++++++++++++++--------- tests/lsp-event-handler-tests.cpp | 62 ++++++++++++++++++++++++++++- tests/mocks/diagnostics-mock.hpp | 2 +- 6 files changed, 115 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 9922470..be9e8dc 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ You can configure diagnostics with `json-rpc` request during the intitialization "configuration": { "buildOptions": [], "deviceID": 0, - "maxNumberOfProblems": 100 + "maxNumberOfProblems": 127 } } } diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index 399f2cd..2f7b8fb 100755 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -26,7 +26,7 @@ struct IDiagnostics virtual ~IDiagnostics() = default; virtual void SetBuildOptions(const nlohmann::json& options) = 0; - virtual void SetMaxProblemsCount(int maxNumberOfProblems) = 0; + virtual void SetMaxProblemsCount(uint64_t maxNumberOfProblems) = 0; virtual void SetOpenCLDevice(uint32_t identifier) = 0; virtual nlohmann::json Get(const Source& source) = 0; }; diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index 74c9aba..f1bf002 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -63,7 +63,7 @@ class Diagnostics final : public IDiagnostics explicit Diagnostics(std::shared_ptr clInfo); void SetBuildOptions(const nlohmann::json& options); - void SetMaxProblemsCount(int maxNumberOfProblems); + void SetMaxProblemsCount(uint64_t maxNumberOfProblems); void SetOpenCLDevice(uint32_t identifier); nlohmann::json Get(const Source& source); @@ -76,7 +76,7 @@ class Diagnostics final : public IDiagnostics std::optional m_device; std::regex m_regex {"^(.*):(\\d+):(\\d+): ((fatal )?error|warning|Scholar): (.*)$"}; std::string m_BuildOptions; - int m_maxNumberOfProblems = 100; + uint64_t m_maxNumberOfProblems = INT8_MAX; }; Diagnostics::Diagnostics(std::shared_ptr clInfo) : m_clInfo {std::move(clInfo)} @@ -201,7 +201,7 @@ nlohmann::json Diagnostics::BuildDiagnostics(const std::string& buildLog, const std::smatch matches; auto errorLines = utils::SplitString(buildLog, "\n"); json diagnostics; - int count = 0; + uint64_t count = 0; for (auto errLine : errorLines) { std::regex_search(errLine, matches, m_regex); @@ -279,7 +279,7 @@ void Diagnostics::SetBuildOptions(const json& options) } } -void Diagnostics::SetMaxProblemsCount(int maxNumberOfProblems) +void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) { spdlog::get(logger)->trace("Set max number of problems: {}", maxNumberOfProblems); m_maxNumberOfProblems = maxNumberOfProblems; diff --git a/src/lsp.cpp b/src/lsp.cpp index 3513677..bc5b65a 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -17,6 +17,24 @@ using namespace nlohmann; +namespace { + +std::optional GetNestedValue(const nlohmann::json &j, const std::vector &keys) +{ + const nlohmann::json *current = &j; + for (const auto &key : keys) + { + if (!current->contains(key)) + { + return std::nullopt; + } + current = &(*current)[key]; + } + return *current; +} + +} // namespace + namespace ocls { constexpr char logger[] = "lsp"; @@ -124,26 +142,38 @@ std::optional LSPServerEventsHandler::GetNextResponse() void LSPServerEventsHandler::OnInitialize(const json &data) { spdlog::get(logger)->debug("Received 'initialize' request"); - try - { - m_capabilities.hasConfigurationCapability = - data["params"]["capabilities"]["workspace"]["configuration"].get(); - m_capabilities.supportDidChangeConfiguration = - data["params"]["capabilities"]["workspace"]["didChangeConfiguration"]["dynamicRegistration"].get(); - auto configuration = data["params"]["initializationOptions"]["configuration"]; - - auto buildOptions = configuration["buildOptions"]; - m_diagnostics->SetBuildOptions(buildOptions); - auto maxNumberOfProblems = configuration["maxNumberOfProblems"].get(); - m_diagnostics->SetMaxProblemsCount(static_cast(maxNumberOfProblems)); + auto configurationCapability = GetNestedValue(data, {"params", "capabilities", "workspace", "configuration"}); + if (configurationCapability) + { + m_capabilities.hasConfigurationCapability = configurationCapability->get(); + } - auto deviceID = configuration["deviceID"].get(); - m_diagnostics->SetOpenCLDevice(static_cast(deviceID)); + auto didChangeConfiguration = + GetNestedValue(data, {"params", "capabilities", "workspace", "didChangeConfiguration", "dynamicRegistration"}); + if (didChangeConfiguration) + { + m_capabilities.supportDidChangeConfiguration = didChangeConfiguration->get(); } - catch (nlohmann::json::out_of_range &err) + + auto configuration = GetNestedValue(data, {"params", "initializationOptions", "configuration"}); + if (configuration) { - spdlog::get(logger)->error("Failed to parse initialize parameters, {}", err.what()); + auto buildOptions = GetNestedValue(*configuration, {"buildOptions"}); + auto maxNumberOfProblems = GetNestedValue(*configuration, {"maxNumberOfProblems"}); + auto deviceID = GetNestedValue(*configuration, {"deviceID"}); + if (buildOptions) + { + m_diagnostics->SetBuildOptions(*buildOptions); + } + if (maxNumberOfProblems) + { + m_diagnostics->SetMaxProblemsCount(*maxNumberOfProblems); + } + if (deviceID) + { + m_diagnostics->SetOpenCLDevice(*deviceID); + } } json capabilities = { @@ -380,8 +410,7 @@ std::shared_ptr CreateLSPServer( return std::make_shared(std::move(jrpc), std::move(handler)); } -std::shared_ptr CreateLSPServer( - std::shared_ptr jrpc, std::shared_ptr diagnostics) +std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) { auto handler = std::make_shared(jrpc, diagnostics); return std::make_shared(std::move(jrpc), std::move(handler)); diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index 485a4eb..f81cf30 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -35,7 +35,7 @@ class LSPTest : public ::testing::Test } // namespace -TEST_F(LSPTest, OnInitialize) +TEST_F(LSPTest, OnInitialize_shouldBuildResponse_andCallDiagnosticsSetters) { nlohmann::json testData = R"({ "params": { @@ -56,6 +56,21 @@ TEST_F(LSPTest, OnInitialize) "id": 1 })"_json; + nlohmann::json expectedResponse = R"({ + "id": 1, + "result": { + "capabilities": { + "textDocumentSync": { + "openClose": true, + "change": 1, + "willSave": false, + "willSaveWaitUntil": false, + "save": false + } + } + } + })"_json; + EXPECT_CALL( *mockDiagnostics, SetBuildOptions(testData["params"]["initializationOptions"]["configuration"]["buildOptions"])) .Times(1); @@ -63,4 +78,49 @@ TEST_F(LSPTest, OnInitialize) EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); handler->OnInitialize(testData); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} + +TEST_F(LSPTest, OnInitialize_withMissingConfigurationFields_shouldBuildResponse_andNotCallDiagnosticsSetters) +{ + nlohmann::json testData = R"({ + "params": { + "capabilities": { + "workspace": { + "configuration": true, + "didChangeConfiguration": {"dynamicRegistration": true} + } + }, + "initializationOptions": { + "configuration": { + + } + } + }, + "id": 1 + })"_json; + + nlohmann::json expectedResponse = R"({ + "id": 1, + "result": { + "capabilities": { + "textDocumentSync": { + "openClose": true, + "change": 1, + "willSave": false, + "willSaveWaitUntil": false, + "save": false + } + } + } + })"_json; + + handler->OnInitialize(testData); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); } diff --git a/tests/mocks/diagnostics-mock.hpp b/tests/mocks/diagnostics-mock.hpp index 8700d5d..a85c4cb 100644 --- a/tests/mocks/diagnostics-mock.hpp +++ b/tests/mocks/diagnostics-mock.hpp @@ -14,7 +14,7 @@ class DiagnosticsMock : public ocls::IDiagnostics public: MOCK_METHOD(void, SetBuildOptions, (const nlohmann::json&), (override)); - MOCK_METHOD(void, SetMaxProblemsCount, (int), (override)); + MOCK_METHOD(void, SetMaxProblemsCount, (uint64_t), (override)); MOCK_METHOD(void, SetOpenCLDevice, (uint32_t), (override)); From d11aa7d888be3606f6afe8c9bf7c5d2cdac5b5eb Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 30 Jul 2023 23:35:38 +0300 Subject: [PATCH 09/32] Refactor OnInitialized, optimize id generation, add tests --- include/lsp.hpp | 3 +- include/utils.hpp | 10 +++- src/lsp.cpp | 41 +++++++++++---- src/utils.cpp | 37 +++++++++----- tests/lsp-event-handler-tests.cpp | 83 +++++++++++++++++++++++++++++-- tests/mocks/generator-mock.hpp | 17 +++++++ 6 files changed, 163 insertions(+), 28 deletions(-) create mode 100644 tests/mocks/generator-mock.hpp diff --git a/include/lsp.hpp b/include/lsp.hpp index 9575056..3bee0a4 100644 --- a/include/lsp.hpp +++ b/include/lsp.hpp @@ -12,6 +12,7 @@ #include "diagnostics.hpp" #include "jsonrpc.hpp" +#include "utils.hpp" namespace ocls { @@ -41,7 +42,7 @@ struct ILSPServerEventsHandler }; std::shared_ptr CreateLSPEventsHandler( - std::shared_ptr jrpc, std::shared_ptr diagnostics); + std::shared_ptr jrpc, std::shared_ptr diagnostics, std::shared_ptr generator); std::shared_ptr CreateLSPServer( std::shared_ptr jrpc, std::shared_ptr handler); diff --git a/include/utils.hpp b/include/utils.hpp index 6840bb7..5563bff 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -16,7 +16,15 @@ namespace ocls::utils { -std::string GenerateId(); +struct IGenerator +{ + virtual ~IGenerator() = default; + + virtual std::string GenerateID() = 0; +}; + +std::shared_ptr CreateDefaultGenerator(); + void Trim(std::string& s); std::vector SplitString(const std::string& str, const std::string& pattern); std::string UriToPath(const std::string& uri); diff --git a/src/lsp.cpp b/src/lsp.cpp index bc5b65a..845d49c 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -8,7 +8,6 @@ #include "lsp.hpp" #include "diagnostics.hpp" #include "jsonrpc.hpp" -#include "utils.hpp" #include #include @@ -79,9 +78,13 @@ class LSPServer final class LSPServerEventsHandler final : public ILSPServerEventsHandler { public: - LSPServerEventsHandler(std::shared_ptr jrpc, std::shared_ptr diagnostics) + LSPServerEventsHandler( + std::shared_ptr jrpc, + std::shared_ptr diagnostics, + std::shared_ptr generator) : m_jrpc {std::move(jrpc)} , m_diagnostics {std::move(diagnostics)} + , m_generator {std::move(generator)} {} void BuildDiagnosticsRespond(const std::string &uri, const std::string &content); @@ -99,6 +102,7 @@ class LSPServerEventsHandler final : public ILSPServerEventsHandler private: std::shared_ptr m_jrpc; std::shared_ptr m_diagnostics; + std::shared_ptr m_generator; std::queue m_outQueue; Capabilities m_capabilities; std::queue> m_requests; @@ -119,7 +123,7 @@ void LSPServerEventsHandler::GetConfiguration() json buildOptions = {{"section", "OpenCL.server.buildOptions"}}; json maxNumberOfProblems = {{"section", "OpenCL.server.maxNumberOfProblems"}}; json openCLDeviceID = {{"section", "OpenCL.server.deviceID"}}; - const auto requestId = utils::GenerateId(); + const auto requestId = m_generator->GenerateID(); m_requests.push(std::make_pair("workspace/configuration", requestId)); m_outQueue.push( {{"id", requestId}, @@ -142,6 +146,12 @@ std::optional LSPServerEventsHandler::GetNextResponse() void LSPServerEventsHandler::OnInitialize(const json &data) { spdlog::get(logger)->debug("Received 'initialize' request"); + if (!data.contains("id")) + { + spdlog::get(logger)->error("'initialize' message does not contain 'id'"); + return; + } + auto requestId = data["id"]; auto configurationCapability = GetNestedValue(data, {"params", "capabilities", "workspace", "configuration"}); if (configurationCapability) @@ -187,12 +197,20 @@ void LSPServerEventsHandler::OnInitialize(const json &data) }}, }; - m_outQueue.push({{"id", data["id"]}, {"result", {{"capabilities", capabilities}}}}); + m_outQueue.push({{"id", requestId}, {"result", {{"capabilities", capabilities}}}}); } -void LSPServerEventsHandler::OnInitialized(const json &) +void LSPServerEventsHandler::OnInitialized(const json &data) { spdlog::get(logger)->debug("Received 'initialized' message"); + if (!data.contains("id")) + { + spdlog::get(logger)->error("'initialized' message does not contain 'id'"); + return; + } + + auto requestId = data["id"]; + if (!m_capabilities.supportDidChangeConfiguration) { spdlog::get(logger)->debug("Does not support didChangeConfiguration registration"); @@ -200,7 +218,7 @@ void LSPServerEventsHandler::OnInitialized(const json &) } json registrations = {{ - {"id", utils::GenerateId()}, + {"id", m_generator->GenerateID()}, {"method", "workspace/didChangeConfiguration"}, }}; @@ -208,7 +226,7 @@ void LSPServerEventsHandler::OnInitialized(const json &) {"registrations", registrations}, }; - m_outQueue.push({{"id", utils::GenerateId()}, {"method", "client/registerCapability"}, {"params", params}}); + m_outQueue.push({{"id", requestId}, {"method", "client/registerCapability"}, {"params", params}}); } void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, const std::string &content) @@ -399,9 +417,11 @@ void LSPServer::Interrupt() } std::shared_ptr CreateLSPEventsHandler( - std::shared_ptr jrpc, std::shared_ptr diagnostics) + std::shared_ptr jrpc, + std::shared_ptr diagnostics, + std::shared_ptr generator) { - return std::make_shared(jrpc, diagnostics); + return std::make_shared(jrpc, diagnostics, generator); } std::shared_ptr CreateLSPServer( @@ -412,7 +432,8 @@ std::shared_ptr CreateLSPServer( std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) { - auto handler = std::make_shared(jrpc, diagnostics); + auto generator = utils::CreateDefaultGenerator(); + auto handler = std::make_shared(jrpc, diagnostics, generator); return std::make_shared(std::move(jrpc), std::move(handler)); } diff --git a/src/utils.cpp b/src/utils.cpp index 3660e4c..c3bccb9 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -16,22 +17,32 @@ namespace ocls::utils { -std::string GenerateId() +class DefaultGenerator final : public IGenerator { - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, 255); - std::string identifier; - std::stringstream hex; - for (auto i = 0; i < 16; ++i) +public: + DefaultGenerator() : gen(rd()), dis(0, 255) {} + + std::string GenerateID() { - const auto rc = dis(gen); - hex << std::hex << rc; - auto str = hex.str(); - identifier.append(str.length() < 2 ? '0' + str : str); - hex.str(std::string()); + std::stringstream hex; + hex << std::hex; + for (auto i = 0; i < 16; ++i) + { + hex << std::setw(2) << std::setfill('0') << dis(gen); + } + return hex.str(); } - return identifier; + +private: + std::random_device rd; + std::mt19937 gen; + std::uniform_int_distribution<> dis; +}; + + +std::shared_ptr CreateDefaultGenerator() +{ + return std::make_shared(); } void Trim(std::string& s) diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index f81cf30..c9a7dad 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -6,8 +6,10 @@ // #include "lsp.hpp" +#include "utils.hpp" #include "jsonrpc-mock.hpp" #include "diagnostics-mock.hpp" +#include "generator-mock.hpp" #include #include @@ -22,19 +24,26 @@ class LSPTest : public ::testing::Test protected: std::shared_ptr mockJsonRPC; std::shared_ptr mockDiagnostics; + std::shared_ptr mockGenerator; std::shared_ptr handler; void SetUp() override { mockJsonRPC = std::make_shared(); mockDiagnostics = std::make_shared(); - handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics); + mockGenerator = std::make_shared(); + + ON_CALL(*mockGenerator, GenerateID()).WillByDefault(::testing::Return("12345678")); + + handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics, mockGenerator); } }; } // namespace +// OnInitialize + TEST_F(LSPTest, OnInitialize_shouldBuildResponse_andCallDiagnosticsSetters) { nlohmann::json testData = R"({ @@ -100,11 +109,11 @@ TEST_F(LSPTest, OnInitialize_withMissingConfigurationFields_shouldBuildResponse_ } } }, - "id": 1 + "id": "1" })"_json; nlohmann::json expectedResponse = R"({ - "id": 1, + "id": "1", "result": { "capabilities": { "textDocumentSync": { @@ -124,3 +133,71 @@ TEST_F(LSPTest, OnInitialize_withMissingConfigurationFields_shouldBuildResponse_ EXPECT_TRUE(response.has_value()); EXPECT_EQ(*response, expectedResponse); } + +// OnInitialized + +TEST_F(LSPTest, OnInitialized_withDidChangeConfigurationSupport_shouldBuildResponse) +{ + nlohmann::json initData = R"({ + "params": { + "capabilities": { + "workspace": { + "didChangeConfiguration": {"dynamicRegistration": true} + } + }, + "initializationOptions": { + "configuration": { + } + } + }, + "id": "1" + })"_json; + + nlohmann::json expectedResponse = R"({ + "id": "1", + "method": "client/registerCapability", + "params": { + "registrations": [{ + "id": "12345678", + "method": "workspace/didChangeConfiguration" + }] + } + })"_json; + + handler->OnInitialize(initData); + handler->GetNextResponse(); + + EXPECT_CALL(*mockGenerator, GenerateID()).Times(1); + + handler->OnInitialized(R"({"id": "1"})"_json); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} + +TEST_F(LSPTest, OnInitialized_withoutDidChangeConfigurationSupport_shouldNotBuildResponse) +{ + nlohmann::json initData = R"({ + "params": { + "capabilities": { + "workspace": { + "didChangeConfiguration": {"dynamicRegistration": false} + } + }, + "initializationOptions": { + "configuration": { + } + } + }, + "id": "1" + })"_json; + + handler->OnInitialize(initData); + handler->GetNextResponse(); + + handler->OnInitialized(R"({"id": "1"})"_json); + auto response = handler->GetNextResponse(); + + EXPECT_FALSE(response.has_value()); +} diff --git a/tests/mocks/generator-mock.hpp b/tests/mocks/generator-mock.hpp new file mode 100644 index 0000000..b3680a2 --- /dev/null +++ b/tests/mocks/generator-mock.hpp @@ -0,0 +1,17 @@ +// +// generator-mock.hpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 7/30/23. +// + + +#include "utils.hpp" + +#include + +class GeneratorMock : public ocls::utils::IGenerator +{ +public: + MOCK_METHOD(std::string, GenerateID, (), (override)); +}; From c3c3c2ce56b5e928b51612388046fe25bc35e0c9 Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 1 Aug 2023 00:17:05 +0300 Subject: [PATCH 10/32] Add test for LSPServerEventsHandler::BuildDiagnosticsRespond --- include/diagnostics.hpp | 7 ++++- tests/lsp-event-handler-tests.cpp | 50 +++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) mode change 100755 => 100644 include/diagnostics.hpp diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp old mode 100755 new mode 100644 index 2f7b8fb..e5eecc6 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -19,12 +19,17 @@ struct Source { std::string filePath; std::string text; + + bool operator==(const Source& other) const + { + return filePath == other.filePath && text == other.text; + } }; struct IDiagnostics { virtual ~IDiagnostics() = default; - + virtual void SetBuildOptions(const nlohmann::json& options) = 0; virtual void SetMaxProblemsCount(uint64_t maxNumberOfProblems) = 0; virtual void SetOpenCLDevice(uint32_t identifier) = 0; diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index c9a7dad..d1546f0 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -201,3 +201,53 @@ TEST_F(LSPTest, OnInitialized_withoutDidChangeConfigurationSupport_shouldNotBuil EXPECT_FALSE(response.has_value()); } + +// BuildDiagnosticsRespond + +TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) +{ + std::string uri = "kernel.cl"; + std::string content = + R"(__kernel void add(__global double* a, __global double* b, __global double* c, const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + c[id] = a[id] + b[id]; + } + })"; + nlohmann::json expected_diagnostics; + nlohmann::json diagnostic = { + {"source", uri}, + {"range", + { + {"start", + { + {"line", 1}, + {"character", 1}, + }}, + {"end", + { + {"line", 1}, + {"character", 2}, + }}, + {"severity", 2}, + {"message", "message"}, + }}}; + expected_diagnostics.emplace_back(diagnostic); + Source expectedSource {uri, content}; + nlohmann::json expectedResponse = { + {"method", "textDocument/publishDiagnostics"}, + {"params", + { + {"uri", uri}, + {"diagnostics", expected_diagnostics}, + }}}; + + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expected_diagnostics)); + EXPECT_CALL(*mockDiagnostics, Get(expectedSource)).Times(1); + + handler->BuildDiagnosticsRespond(uri, content); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} From 03b412697b1afc0f0ba9d3c93e56133ad4883b76 Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 1 Aug 2023 23:35:02 +0300 Subject: [PATCH 11/32] Add test for throwing LSPServerEventsHandler::BuildDiagnosticsRespond --- tests/lsp-event-handler-tests.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index d1546f0..e8b596e 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -251,3 +251,16 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) EXPECT_TRUE(response.has_value()); EXPECT_EQ(*response, expectedResponse); } + +TEST_F(LSPTest, BuildDiagnosticsRespond_withException_shouldReplyWithError) +{ + std::string uri = "kernel.cl"; + std::string content = "__kernel void add() {}"; + Source expectedSource {uri, content}; + + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Throw(std::runtime_error("Exception"))); + EXPECT_CALL(*mockDiagnostics, Get(expectedSource)).Times(1); + EXPECT_CALL(*mockJsonRPC, WriteError(JRPCErrorCode::InternalError, "Failed to get diagnostics: Exception")).Times(1); + + handler->BuildDiagnosticsRespond(uri, content); +} From aeffd691e938f373076ac28ddea89e42a7d14047 Mon Sep 17 00:00:00 2001 From: Galarius Date: Thu, 3 Aug 2023 21:20:42 +0300 Subject: [PATCH 12/32] Add test for LSPServerEventsHandler::OnTextOpen and OnTextChanged --- src/lsp.cpp | 24 +++-- tests/lsp-event-handler-tests.cpp | 142 +++++++++++++++++++++--------- 2 files changed, 119 insertions(+), 47 deletions(-) diff --git a/src/lsp.cpp b/src/lsp.cpp index 845d49c..b7883ee 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -256,18 +256,28 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con void LSPServerEventsHandler::OnTextOpen(const json &data) { spdlog::get(logger)->debug("Received 'textOpen' message"); - std::string srcUri = data["params"]["textDocument"]["uri"].get(); - std::string content = data["params"]["textDocument"]["text"].get(); - BuildDiagnosticsRespond(srcUri, content); + auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); + auto content = GetNestedValue(data, {"params", "textDocument", "text"}); + if (uri && content) { + BuildDiagnosticsRespond(uri->get(), content->get()); + } + } void LSPServerEventsHandler::OnTextChanged(const json &data) { spdlog::get(logger)->debug("Received 'textChanged' message"); - std::string srcUri = data["params"]["textDocument"]["uri"].get(); - std::string content = data["params"]["contentChanges"][0]["text"].get(); - - BuildDiagnosticsRespond(srcUri, content); + auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); + auto contentChanges = GetNestedValue(data, {"params", "contentChanges" }); + if(contentChanges && contentChanges->size() > 0) { + // Only one content change with the full content of the document is supported. + auto lastIdx = contentChanges->size() - 1; + auto lastContent = (*contentChanges)[lastIdx]; + if(lastContent.contains("text")) { + auto text = lastContent["text"].get(); + BuildDiagnosticsRespond(uri->get(), text); + } + } } void LSPServerEventsHandler::OnConfiguration(const json &data) diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index e8b596e..7bd583e 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -13,6 +13,7 @@ #include #include +#include using namespace ocls; using namespace nlohmann; @@ -37,6 +38,46 @@ class LSPTest : public ::testing::Test handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics, mockGenerator); } + + std::tuple GetTestSource() const { + std::string uri = "kernel.cl"; + std::string content = + R"(__kernel void add(__global double* a, __global double* b, __global double* c, const unsigned int n) { + int id = get_global_id(0); + if (id < n) { + c[id] = a[id] + b[id]; + } + })"; + return std::make_tuple(uri, content); + } + + nlohmann::json GetTestDiagnostics(const std::string& uri) const { + return {{ + {"source", uri}, + {"range", { + {"start", { + {"line", 1}, + {"character", 1}, + }}, + {"end", { + {"line", 1}, + {"character", 2}, + }}, + {"severity", 2}, + {"message", "message"}, + }} + }}; + } + + nlohmann::json GetTestDiagnosticsResponse(const std::string& uri) const { + return { + {"method", "textDocument/publishDiagnostics"}, + {"params", + { + {"uri", uri}, + {"diagnostics", GetTestDiagnostics(uri)}, + }}}; + } }; } // namespace @@ -206,44 +247,12 @@ TEST_F(LSPTest, OnInitialized_withoutDidChangeConfigurationSupport_shouldNotBuil TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) { - std::string uri = "kernel.cl"; - std::string content = - R"(__kernel void add(__global double* a, __global double* b, __global double* c, const unsigned int n) { - int id = get_global_id(0); - if (id < n) { - c[id] = a[id] + b[id]; - } - })"; - nlohmann::json expected_diagnostics; - nlohmann::json diagnostic = { - {"source", uri}, - {"range", - { - {"start", - { - {"line", 1}, - {"character", 1}, - }}, - {"end", - { - {"line", 1}, - {"character", 2}, - }}, - {"severity", 2}, - {"message", "message"}, - }}}; - expected_diagnostics.emplace_back(diagnostic); - Source expectedSource {uri, content}; - nlohmann::json expectedResponse = { - {"method", "textDocument/publishDiagnostics"}, - {"params", - { - {"uri", uri}, - {"diagnostics", expected_diagnostics}, - }}}; - - ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expected_diagnostics)); - EXPECT_CALL(*mockDiagnostics, Get(expectedSource)).Times(1); + auto [uri, content] = GetTestSource(); + auto expectedDiagnostics = GetTestDiagnostics(uri); + auto expectedResponse = GetTestDiagnosticsResponse(uri); + + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, Get(Source{uri, content})).Times(1); handler->BuildDiagnosticsRespond(uri, content); auto response = handler->GetNextResponse(); @@ -254,8 +263,7 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) TEST_F(LSPTest, BuildDiagnosticsRespond_withException_shouldReplyWithError) { - std::string uri = "kernel.cl"; - std::string content = "__kernel void add() {}"; + auto [uri, content] = GetTestSource(); Source expectedSource {uri, content}; ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Throw(std::runtime_error("Exception"))); @@ -264,3 +272,57 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_withException_shouldReplyWithError) handler->BuildDiagnosticsRespond(uri, content); } + +// OnTextOpen + +TEST_F(LSPTest, OnTextOpen_shouldBuildResponse) +{ + auto [uri, content] = GetTestSource(); + auto expectedDiagnostics = GetTestDiagnostics(uri); + auto expectedResponse = GetTestDiagnosticsResponse(uri); + nlohmann::json request = { + {"params", { + {"textDocument", { + {"uri", uri}, + {"text", content} + }} + }} + }; + + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); + + handler->OnTextOpen(request); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} + +// OnTextChanged + +TEST_F(LSPTest, OnTextChanged_shouldBuildResponse) +{ + auto [uri, content] = GetTestSource(); + auto expectedDiagnostics = GetTestDiagnostics(uri); + auto expectedResponse = GetTestDiagnosticsResponse(uri); + nlohmann::json request = { + {"params", { + {"textDocument", { + {"uri", uri}, + }}, + {"contentChanges", {{ + {"text", content} + }}} + }} + }; + + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); + + handler->OnTextChanged(request); + auto response = handler->GetNextResponse(); + + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} From d2a79a7ba03fbbd2b9381c8423c4758bab6be7aa Mon Sep 17 00:00:00 2001 From: Galarius Date: Thu, 3 Aug 2023 21:46:54 +0300 Subject: [PATCH 13/32] Update LSPServerEventsHandler::OnConfiguration and add test --- src/lsp.cpp | 55 ++++++++++++++++++++----------- tests/lsp-event-handler-tests.cpp | 22 ++++++++++++- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/lsp.cpp b/src/lsp.cpp index b7883ee..c132d5c 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -18,6 +18,11 @@ using namespace nlohmann; namespace { +constexpr int BuildOptions = 0; +constexpr int MaxProblemsCount = 1; +constexpr int DeviceID = 2; +constexpr int NumConfigurations = 2; + std::optional GetNestedValue(const nlohmann::json &j, const std::vector &keys) { const nlohmann::json *current = &j; @@ -282,34 +287,44 @@ void LSPServerEventsHandler::OnTextChanged(const json &data) void LSPServerEventsHandler::OnConfiguration(const json &data) { - spdlog::get(logger)->debug("Received 'configuration' respond"); - auto result = data["result"]; - if (result.empty()) + auto log = spdlog::get(logger); + log->debug("Received 'configuration' respond"); + + try { - spdlog::get(logger)->warn("Empty result"); - return; - } + auto result = data.at("result"); + if (result.empty()) + { + log->warn("Empty configuration"); + return; + } - if (result.size() != 3) - { - spdlog::get(logger)->warn("Unexpected result items count"); - return; - } + if (result.size() < NumConfigurations) + { + log->warn("Unexpected number of options"); + return; + } - try - { - auto buildOptions = result[0]; - m_diagnostics->SetBuildOptions(buildOptions); + if (result[BuildOptions].is_array()) + { + m_diagnostics->SetBuildOptions(result[BuildOptions]); + } - auto maxProblemsCount = result[1].get(); - m_diagnostics->SetMaxProblemsCount(static_cast(maxProblemsCount)); + if (result[MaxProblemsCount].is_number_integer()) + { + auto maxProblemsCount = result[MaxProblemsCount].get(); + m_diagnostics->SetMaxProblemsCount(static_cast(maxProblemsCount)); + } - auto deviceID = result[2].get(); - m_diagnostics->SetOpenCLDevice(static_cast(deviceID)); + if (result[DeviceID].is_number_integer()) + { + auto deviceID = result[DeviceID].get(); + m_diagnostics->SetOpenCLDevice(static_cast(deviceID)); + } } catch (std::exception &err) { - spdlog::get(logger)->error("Failed to update settings, {}", err.what()); + log->error("Failed to update settings, {}", err.what()); } } diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index 7bd583e..b2321db 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -97,7 +97,7 @@ TEST_F(LSPTest, OnInitialize_shouldBuildResponse_andCallDiagnosticsSetters) }, "initializationOptions": { "configuration": { - "buildOptions": { "option": "value" }, + "buildOptions": [ "-I", "/usr/local/include" ], "maxNumberOfProblems": 10, "deviceID": 1 } @@ -326,3 +326,23 @@ TEST_F(LSPTest, OnTextChanged_shouldBuildResponse) EXPECT_TRUE(response.has_value()); EXPECT_EQ(*response, expectedResponse); } + +// OnConfiguration + +TEST_F(LSPTest, OnConfiguration_shouldUpdateSettings) { + nlohmann::json data = R"({ + "result": [ + ["-I", "/usr/local/include"], + 100, + 1 + ] + })"_json; + + // Set up expectations for the diagnostics object + EXPECT_CALL(*mockDiagnostics, SetBuildOptions(R"(["-I", "/usr/local/include"])"_json)).Times(1); + EXPECT_CALL(*mockDiagnostics, SetMaxProblemsCount(100)).Times(1); + EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); + + // Call the function under test + handler->OnConfiguration(data); +} From a77cb9589d1c1113750ef03bc9140f4cb1d63546 Mon Sep 17 00:00:00 2001 From: Galarius Date: Thu, 3 Aug 2023 22:19:26 +0300 Subject: [PATCH 14/32] Add test for LSPServerEventsHandler::GetConfiguration and OnRespond --- src/lsp.cpp | 22 +++++++- tests/lsp-event-handler-tests.cpp | 93 ++++++++++++++++++++++++++++++- 2 files changed, 110 insertions(+), 5 deletions(-) diff --git a/src/lsp.cpp b/src/lsp.cpp index c132d5c..ff0a571 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -331,14 +331,30 @@ void LSPServerEventsHandler::OnConfiguration(const json &data) void LSPServerEventsHandler::OnRespond(const json &data) { spdlog::get(logger)->debug("Received client respond"); - const auto id = data["id"]; - if (!m_requests.empty()) + if (m_requests.empty()) { + spdlog::get(logger)->warn("Unexpected respond {}", data.dump()); + return; + } + + try { + const auto id = data["id"]; auto request = m_requests.front(); - if (id == request.second && request.first == "workspace/configuration") + if (id == request.second && + "workspace/configuration" == request.first) + { OnConfiguration(data); + } + else + { + spdlog::get(logger)->warn("Out of order respond"); + } m_requests.pop(); } + catch (std::exception &err) + { + spdlog::get(logger)->error("OnRespond failed, {}", err.what()); + } } void LSPServerEventsHandler::OnShutdown(const json &data) diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index b2321db..744bf61 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -338,11 +338,100 @@ TEST_F(LSPTest, OnConfiguration_shouldUpdateSettings) { ] })"_json; - // Set up expectations for the diagnostics object EXPECT_CALL(*mockDiagnostics, SetBuildOptions(R"(["-I", "/usr/local/include"])"_json)).Times(1); EXPECT_CALL(*mockDiagnostics, SetMaxProblemsCount(100)).Times(1); EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); - // Call the function under test handler->OnConfiguration(data); } + +// GetConfiguration + +TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapabilityNotSet_shouldDoNothing) { + handler->GetConfiguration(); + EXPECT_FALSE(handler->GetNextResponse().has_value()); +} + +TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildResponse) { + nlohmann::json initialData = R"({ + "params": { + "capabilities": { + "workspace": { + "configuration": true, + "didChangeConfiguration": {"dynamicRegistration": true} + } + }, + "initializationOptions": { + "configuration": { + + } + } + }, + "id": "1" + })"_json; + + nlohmann::json expectedResponse = R"({ + "id": "12345678", + "method": "workspace/configuration", + "params": { + "items": [ + {"section": "OpenCL.server.buildOptions"}, + {"section": "OpenCL.server.maxNumberOfProblems"}, + {"section": "OpenCL.server.deviceID"} + ] + } + })"_json; + + EXPECT_CALL(*mockGenerator, GenerateID()).Times(1); + + handler->OnInitialize(initialData); + handler->GetNextResponse(); + handler->GetConfiguration(); + + auto response = handler->GetNextResponse(); + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} + +// OnRespond + +TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) { + + nlohmann::json initialData = R"({ + "params": { + "capabilities": { + "workspace": { + "configuration": true, + "didChangeConfiguration": {"dynamicRegistration": true} + } + }, + "initializationOptions": { + "configuration": { + + } + } + }, + "id": "1" + })"_json; + + nlohmann::json data = R"({ + "id": "12345678", + "result": [ + ["-I", "/usr/local/include"], + 100, + 1 + ] + })"_json; + + EXPECT_CALL(*mockGenerator, GenerateID()).Times(1); + EXPECT_CALL(*mockDiagnostics, SetBuildOptions(R"(["-I", "/usr/local/include"])"_json)).Times(1); + EXPECT_CALL(*mockDiagnostics, SetMaxProblemsCount(100)).Times(1); + EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); + + handler->OnInitialize(initialData); + handler->GetNextResponse(); + handler->GetConfiguration(); + handler->GetNextResponse(); + handler->OnRespond(data); + handler->OnRespond(data); // the second call shouldn't trigger settings update +} From 2b9778b2348f969a319b34a777ed5aaf62b01875 Mon Sep 17 00:00:00 2001 From: Galarius Date: Thu, 3 Aug 2023 22:20:14 +0300 Subject: [PATCH 15/32] Format lsp code --- include/lsp.hpp | 7 +- src/lsp.cpp | 22 +++--- tests/lsp-event-handler-tests.cpp | 109 +++++++++++++++--------------- 3 files changed, 70 insertions(+), 68 deletions(-) diff --git a/include/lsp.hpp b/include/lsp.hpp index 3bee0a4..e9d4a29 100644 --- a/include/lsp.hpp +++ b/include/lsp.hpp @@ -42,12 +42,13 @@ struct ILSPServerEventsHandler }; std::shared_ptr CreateLSPEventsHandler( - std::shared_ptr jrpc, std::shared_ptr diagnostics, std::shared_ptr generator); + std::shared_ptr jrpc, + std::shared_ptr diagnostics, + std::shared_ptr generator); std::shared_ptr CreateLSPServer( std::shared_ptr jrpc, std::shared_ptr handler); -std::shared_ptr CreateLSPServer( - std::shared_ptr jrpc, std::shared_ptr diagnostics); +std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics); } // namespace ocls diff --git a/src/lsp.cpp b/src/lsp.cpp index ff0a571..f99d3cf 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -263,22 +263,24 @@ void LSPServerEventsHandler::OnTextOpen(const json &data) spdlog::get(logger)->debug("Received 'textOpen' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); auto content = GetNestedValue(data, {"params", "textDocument", "text"}); - if (uri && content) { + if (uri && content) + { BuildDiagnosticsRespond(uri->get(), content->get()); } - } void LSPServerEventsHandler::OnTextChanged(const json &data) { spdlog::get(logger)->debug("Received 'textChanged' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); - auto contentChanges = GetNestedValue(data, {"params", "contentChanges" }); - if(contentChanges && contentChanges->size() > 0) { + auto contentChanges = GetNestedValue(data, {"params", "contentChanges"}); + if (contentChanges && contentChanges->size() > 0) + { // Only one content change with the full content of the document is supported. auto lastIdx = contentChanges->size() - 1; auto lastContent = (*contentChanges)[lastIdx]; - if(lastContent.contains("text")) { + if (lastContent.contains("text")) + { auto text = lastContent["text"].get(); BuildDiagnosticsRespond(uri->get(), text); } @@ -289,7 +291,7 @@ void LSPServerEventsHandler::OnConfiguration(const json &data) { auto log = spdlog::get(logger); log->debug("Received 'configuration' respond"); - + try { auto result = data.at("result"); @@ -331,17 +333,17 @@ void LSPServerEventsHandler::OnConfiguration(const json &data) void LSPServerEventsHandler::OnRespond(const json &data) { spdlog::get(logger)->debug("Received client respond"); - if (m_requests.empty()) { + if (m_requests.empty()) + { spdlog::get(logger)->warn("Unexpected respond {}", data.dump()); return; } - + try { const auto id = data["id"]; auto request = m_requests.front(); - if (id == request.second && - "workspace/configuration" == request.first) + if (id == request.second && "workspace/configuration" == request.first) { OnConfiguration(data); } diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index 744bf61..e6cd258 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -38,8 +38,9 @@ class LSPTest : public ::testing::Test handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics, mockGenerator); } - - std::tuple GetTestSource() const { + + std::tuple GetTestSource() const + { std::string uri = "kernel.cl"; std::string content = R"(__kernel void add(__global double* a, __global double* b, __global double* c, const unsigned int n) { @@ -50,26 +51,30 @@ class LSPTest : public ::testing::Test })"; return std::make_tuple(uri, content); } - - nlohmann::json GetTestDiagnostics(const std::string& uri) const { - return {{ - {"source", uri}, - {"range", { - {"start", { - {"line", 1}, - {"character", 1}, - }}, - {"end", { - {"line", 1}, - {"character", 2}, - }}, - {"severity", 2}, - {"message", "message"}, - }} - }}; + + nlohmann::json GetTestDiagnostics(const std::string& uri) const + { + return { + {{"source", uri}, + {"range", + { + {"start", + { + {"line", 1}, + {"character", 1}, + }}, + {"end", + { + {"line", 1}, + {"character", 2}, + }}, + {"severity", 2}, + {"message", "message"}, + }}}}; } - - nlohmann::json GetTestDiagnosticsResponse(const std::string& uri) const { + + nlohmann::json GetTestDiagnosticsResponse(const std::string& uri) const + { return { {"method", "textDocument/publishDiagnostics"}, {"params", @@ -252,7 +257,7 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) auto expectedResponse = GetTestDiagnosticsResponse(uri); ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); - EXPECT_CALL(*mockDiagnostics, Get(Source{uri, content})).Times(1); + EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); handler->BuildDiagnosticsRespond(uri, content); auto response = handler->GetNextResponse(); @@ -265,10 +270,11 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_withException_shouldReplyWithError) { auto [uri, content] = GetTestSource(); Source expectedSource {uri, content}; - + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Throw(std::runtime_error("Exception"))); EXPECT_CALL(*mockDiagnostics, Get(expectedSource)).Times(1); - EXPECT_CALL(*mockJsonRPC, WriteError(JRPCErrorCode::InternalError, "Failed to get diagnostics: Exception")).Times(1); + EXPECT_CALL(*mockJsonRPC, WriteError(JRPCErrorCode::InternalError, "Failed to get diagnostics: Exception")) + .Times(1); handler->BuildDiagnosticsRespond(uri, content); } @@ -280,15 +286,8 @@ TEST_F(LSPTest, OnTextOpen_shouldBuildResponse) auto [uri, content] = GetTestSource(); auto expectedDiagnostics = GetTestDiagnostics(uri); auto expectedResponse = GetTestDiagnosticsResponse(uri); - nlohmann::json request = { - {"params", { - {"textDocument", { - {"uri", uri}, - {"text", content} - }} - }} - }; - + nlohmann::json request = {{"params", {{"textDocument", {{"uri", uri}, {"text", content}}}}}}; + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); @@ -307,16 +306,13 @@ TEST_F(LSPTest, OnTextChanged_shouldBuildResponse) auto expectedDiagnostics = GetTestDiagnostics(uri); auto expectedResponse = GetTestDiagnosticsResponse(uri); nlohmann::json request = { - {"params", { - {"textDocument", { - {"uri", uri}, - }}, - {"contentChanges", {{ - {"text", content} - }}} - }} - }; - + {"params", + {{"textDocument", + { + {"uri", uri}, + }}, + {"contentChanges", {{{"text", content}}}}}}}; + ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); @@ -329,7 +325,8 @@ TEST_F(LSPTest, OnTextChanged_shouldBuildResponse) // OnConfiguration -TEST_F(LSPTest, OnConfiguration_shouldUpdateSettings) { +TEST_F(LSPTest, OnConfiguration_shouldUpdateSettings) +{ nlohmann::json data = R"({ "result": [ ["-I", "/usr/local/include"], @@ -347,12 +344,14 @@ TEST_F(LSPTest, OnConfiguration_shouldUpdateSettings) { // GetConfiguration -TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapabilityNotSet_shouldDoNothing) { +TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapabilityNotSet_shouldDoNothing) +{ handler->GetConfiguration(); EXPECT_FALSE(handler->GetNextResponse().has_value()); } -TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildResponse) { +TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildResponse) +{ nlohmann::json initialData = R"({ "params": { "capabilities": { @@ -369,7 +368,7 @@ TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildRespo }, "id": "1" })"_json; - + nlohmann::json expectedResponse = R"({ "id": "12345678", "method": "workspace/configuration", @@ -381,13 +380,13 @@ TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildRespo ] } })"_json; - + EXPECT_CALL(*mockGenerator, GenerateID()).Times(1); - + handler->OnInitialize(initialData); handler->GetNextResponse(); handler->GetConfiguration(); - + auto response = handler->GetNextResponse(); EXPECT_TRUE(response.has_value()); EXPECT_EQ(*response, expectedResponse); @@ -395,8 +394,8 @@ TEST_F(LSPTest, GetConfiguration_whenHasConfigurationCapability_shouldBuildRespo // OnRespond -TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) { - +TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) +{ nlohmann::json initialData = R"({ "params": { "capabilities": { @@ -413,7 +412,7 @@ TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) { }, "id": "1" })"_json; - + nlohmann::json data = R"({ "id": "12345678", "result": [ @@ -427,11 +426,11 @@ TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) { EXPECT_CALL(*mockDiagnostics, SetBuildOptions(R"(["-I", "/usr/local/include"])"_json)).Times(1); EXPECT_CALL(*mockDiagnostics, SetMaxProblemsCount(100)).Times(1); EXPECT_CALL(*mockDiagnostics, SetOpenCLDevice(1)).Times(1); - + handler->OnInitialize(initialData); handler->GetNextResponse(); handler->GetConfiguration(); handler->GetNextResponse(); handler->OnRespond(data); - handler->OnRespond(data); // the second call shouldn't trigger settings update + handler->OnRespond(data); // the second call shouldn't trigger settings update } From c58aac40c9353c429b794cb5b633e90bc00397bb Mon Sep 17 00:00:00 2001 From: Galarius Date: Thu, 3 Aug 2023 22:36:57 +0300 Subject: [PATCH 16/32] Add test for LSPServerEventsHandler::OnShutdown and OnExit --- include/lsp.hpp | 3 ++- include/utils.hpp | 9 +++++++ src/lsp.cpp | 21 +++++++++++----- src/utils.cpp | 18 ++++++++++++- tests/lsp-event-handler-tests.cpp | 42 ++++++++++++++++++++++++++++++- tests/mocks/exit-handler-mock.hpp | 17 +++++++++++++ 6 files changed, 101 insertions(+), 9 deletions(-) create mode 100644 tests/mocks/exit-handler-mock.hpp diff --git a/include/lsp.hpp b/include/lsp.hpp index e9d4a29..c0cbbaf 100644 --- a/include/lsp.hpp +++ b/include/lsp.hpp @@ -44,7 +44,8 @@ struct ILSPServerEventsHandler std::shared_ptr CreateLSPEventsHandler( std::shared_ptr jrpc, std::shared_ptr diagnostics, - std::shared_ptr generator); + std::shared_ptr generator, + std::shared_ptr exitHandler); std::shared_ptr CreateLSPServer( std::shared_ptr jrpc, std::shared_ptr handler); diff --git a/include/utils.hpp b/include/utils.hpp index 5563bff..7e29a79 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -23,8 +23,17 @@ struct IGenerator virtual std::string GenerateID() = 0; }; +struct IExitHandler +{ + virtual ~IExitHandler() = default; + + virtual void OnExit(int code) = 0; +}; + std::shared_ptr CreateDefaultGenerator(); +std::shared_ptr CreateDefaultExitHandler(); + void Trim(std::string& s); std::vector SplitString(const std::string& str, const std::string& pattern); std::string UriToPath(const std::string& uri); diff --git a/src/lsp.cpp b/src/lsp.cpp index f99d3cf..4ad2423 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -86,10 +86,12 @@ class LSPServerEventsHandler final : public ILSPServerEventsHandler LSPServerEventsHandler( std::shared_ptr jrpc, std::shared_ptr diagnostics, - std::shared_ptr generator) + std::shared_ptr generator, + std::shared_ptr exitHandler) : m_jrpc {std::move(jrpc)} , m_diagnostics {std::move(diagnostics)} , m_generator {std::move(generator)} + , m_exitHandler {std::move(exitHandler)} {} void BuildDiagnosticsRespond(const std::string &uri, const std::string &content); @@ -108,6 +110,7 @@ class LSPServerEventsHandler final : public ILSPServerEventsHandler std::shared_ptr m_jrpc; std::shared_ptr m_diagnostics; std::shared_ptr m_generator; + std::shared_ptr m_exitHandler; std::queue m_outQueue; Capabilities m_capabilities; std::queue> m_requests; @@ -370,9 +373,13 @@ void LSPServerEventsHandler::OnExit() { spdlog::get(logger)->debug("Received 'exit', after 'shutdown': {}", m_shutdown ? "yes" : "no"); if (m_shutdown) - exit(EXIT_SUCCESS); + { + m_exitHandler->OnExit(EXIT_SUCCESS); + } else - exit(EXIT_FAILURE); + { + m_exitHandler->OnExit(EXIT_FAILURE); + } } // ILSPServer @@ -462,9 +469,10 @@ void LSPServer::Interrupt() std::shared_ptr CreateLSPEventsHandler( std::shared_ptr jrpc, std::shared_ptr diagnostics, - std::shared_ptr generator) + std::shared_ptr generator, + std::shared_ptr exitHandler) { - return std::make_shared(jrpc, diagnostics, generator); + return std::make_shared(jrpc, diagnostics, generator, exitHandler); } std::shared_ptr CreateLSPServer( @@ -476,7 +484,8 @@ std::shared_ptr CreateLSPServer( std::shared_ptr CreateLSPServer(std::shared_ptr jrpc, std::shared_ptr diagnostics) { auto generator = utils::CreateDefaultGenerator(); - auto handler = std::make_shared(jrpc, diagnostics, generator); + auto exitHandler = utils::CreateDefaultExitHandler(); + auto handler = std::make_shared(jrpc, diagnostics, generator, exitHandler); return std::make_shared(std::move(jrpc), std::move(handler)); } diff --git a/src/utils.cpp b/src/utils.cpp index c3bccb9..44190e8 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -39,12 +39,28 @@ class DefaultGenerator final : public IGenerator std::uniform_int_distribution<> dis; }; - std::shared_ptr CreateDefaultGenerator() { return std::make_shared(); } +class DefaultExitHandler final : public IExitHandler +{ +public: + DefaultExitHandler() = default; + + void OnExit(int code) + { + exit(code); + } +}; + +std::shared_ptr CreateDefaultExitHandler() +{ + return std::make_shared(); +} + + void Trim(std::string& s) { s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index e6cd258..09e7c0b 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -7,6 +7,7 @@ #include "lsp.hpp" #include "utils.hpp" +#include "exit-handler-mock.hpp" #include "jsonrpc-mock.hpp" #include "diagnostics-mock.hpp" #include "generator-mock.hpp" @@ -26,6 +27,7 @@ class LSPTest : public ::testing::Test std::shared_ptr mockJsonRPC; std::shared_ptr mockDiagnostics; std::shared_ptr mockGenerator; + std::shared_ptr mockExitHandler; std::shared_ptr handler; void SetUp() override @@ -33,10 +35,11 @@ class LSPTest : public ::testing::Test mockJsonRPC = std::make_shared(); mockDiagnostics = std::make_shared(); mockGenerator = std::make_shared(); + mockExitHandler = std::make_shared(); ON_CALL(*mockGenerator, GenerateID()).WillByDefault(::testing::Return("12345678")); - handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics, mockGenerator); + handler = CreateLSPEventsHandler(mockJsonRPC, mockDiagnostics, mockGenerator, mockExitHandler); } std::tuple GetTestSource() const @@ -434,3 +437,40 @@ TEST_F(LSPTest, OnRespond_whenConfigurationRespond_shouldUpdateSettings) handler->OnRespond(data); handler->OnRespond(data); // the second call shouldn't trigger settings update } + +// OnShutdown + +TEST_F(LSPTest, OnShutdown_shouldBuildResponse) +{ + nlohmann::json data = R"({ + "id": "12345678" + })"_json; + nlohmann::json expectedResponse = R"({ + "id": "12345678", + "result": null + })"_json; + + handler->OnShutdown(data); + + auto response = handler->GetNextResponse(); + EXPECT_TRUE(response.has_value()); + EXPECT_EQ(*response, expectedResponse); +} + +// OnExit + +TEST_F(LSPTest, OnExit_shouldExitWithFailureByDefault) +{ + EXPECT_CALL(*mockExitHandler, OnExit(EXIT_FAILURE)).Times(1); + handler->OnExit(); +} + +TEST_F(LSPTest, OnExit_shouldExitWithSuccessAfterShutdownCall) +{ + nlohmann::json data = R"({ + "id": "12345678" + })"_json; + handler->OnShutdown(data); + EXPECT_CALL(*mockExitHandler, OnExit(EXIT_SUCCESS)).Times(1); + handler->OnExit(); +} diff --git a/tests/mocks/exit-handler-mock.hpp b/tests/mocks/exit-handler-mock.hpp new file mode 100644 index 0000000..8c7de12 --- /dev/null +++ b/tests/mocks/exit-handler-mock.hpp @@ -0,0 +1,17 @@ +// +// exit-handler-mock.hpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 8/03/23. +// + + +#include "utils.hpp" + +#include + +class ExitHandlerMock : public ocls::utils::IExitHandler +{ +public: + MOCK_METHOD(void, OnExit, (int), (override)); +}; From 8b23afe3405f56b51a503c55ff2e9269980b9c64 Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 4 Aug 2023 21:40:33 +0300 Subject: [PATCH 17/32] Replace --clinfo with subcommand --- src/main.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 5d18724..bf0ea69 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -91,12 +91,14 @@ inline void SetupBinaryStreamMode() int main(int argc, char* argv[]) { bool flagLogTofile = false; - bool flagCLInfo = false; std::string optLogFile = "opencl-language-server.log"; spdlog::level::level_enum optLogLevel = spdlog::level::trace; - CLI::App app {"OpenCL Language Server"}; - app.add_flag("-i,--clinfo", flagCLInfo, "Show information about available OpenCL devices"); + CLI::App app { + "OpenCL Language Server\n" + "The language server communicates with a client using JSON-RPC protocol.\n" + "You can stop the server by sending an interrupt signal followed by any character sent to standard input.\n" + }; app.add_flag("-e,--enable-file-logging", flagLogTofile, "Enable file logging"); app.add_option("-f,--log-file", optLogFile, "Path to log file")->required(false)->capture_default_str(); app.add_option("-l,--log-level", optLogLevel, "Log level") @@ -116,17 +118,22 @@ int main(int argc, char* argv[]) exit(0); }, "Show version"); + + bool flagPrettyPrint = false; + auto clinfoCmd = app.add_subcommand("clinfo", "Show information about available OpenCL devices"); + clinfoCmd->add_flag("-p,--pretty-print", flagPrettyPrint, "Enable pretty-printing"); CLI11_PARSE(app, argc, argv); ConfigureLogging(flagLogTofile, optLogFile, optLogLevel); auto clinfo = CreateCLInfo(); - if (flagCLInfo) + if (*clinfoCmd) { const auto jsonBody = clinfo->json(); - std::cout << jsonBody.dump() << std::endl; - exit(0); + const int indentation = flagPrettyPrint ? 4 : -1; + std::cout << jsonBody.dump(indentation) << std::endl; + return 0; } SetupBinaryStreamMode(); From 8c794521af515757e70226b0d8b0a4d895f2f202 Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 8 Aug 2023 00:14:28 +0300 Subject: [PATCH 18/32] Add 'diagnostics' subcommand --- README.md | 4 +- include/diagnostics.hpp | 5 +- include/utils.hpp | 2 + src/diagnostics.cpp | 36 ++++---- src/lsp.cpp | 2 +- src/main.cpp | 147 +++++++++++++++++++++++++----- src/utils.cpp | 17 ++++ tests/lsp-event-handler-tests.cpp | 16 ++-- tests/mocks/diagnostics-mock.hpp | 6 +- 9 files changed, 185 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index be9e8dc..db02b8f 100644 --- a/README.md +++ b/README.md @@ -33,10 +33,10 @@ You can configure diagnostics with `json-rpc` request during the intitialization ||| | --- | --- | -| `buildOptions` | Build options to be used for building the program. The list of [supported](https://registry.khronos.org/OpenCL/sdk/2.1/docs/man/xhtml/clBuildProgram.html) build options. | +| `buildOptions` | Options to be utilized when building the program. The list of [supported](https://registry.khronos.org/OpenCL/sdk/2.1/docs/man/xhtml/clBuildProgram.html) build options. | | `deviceID` | Device ID or 0 (automatic selection) of the OpenCL device to be used for diagnostics. | | | *Run `./opencl-language-server --clinfo` to get information about available OpenCL devices including identifiers.* | -| `maxNumberOfProblems` | Controls the maximum number of problems produced by the language server. | +| `maxNumberOfProblems` | Controls the maximum number of errors parsed by the language server. | ## Development diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index e5eecc6..17fcb89 100644 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -31,9 +31,12 @@ struct IDiagnostics virtual ~IDiagnostics() = default; virtual void SetBuildOptions(const nlohmann::json& options) = 0; + virtual void SetBuildOptions(const std::string& options) = 0; virtual void SetMaxProblemsCount(uint64_t maxNumberOfProblems) = 0; virtual void SetOpenCLDevice(uint32_t identifier) = 0; - virtual nlohmann::json Get(const Source& source) = 0; + + virtual std::string GetBuildLog(const Source& source) = 0; + virtual nlohmann::json GetDiagnostics(const Source& source) = 0; }; std::shared_ptr CreateDiagnostics(std::shared_ptr clInfo); diff --git a/include/utils.hpp b/include/utils.hpp index 7e29a79..330f1d7 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -38,6 +39,7 @@ void Trim(std::string& s); std::vector SplitString(const std::string& str, const std::string& pattern); std::string UriToPath(const std::string& uri); bool EndsWith(const std::string& str, const std::string& suffix); +std::optional ReadFileContent(std::string_view fileName); namespace internal { // Generates a lookup table for the checksums of all 8-bit values. diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index f1bf002..e20b862 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -63,9 +63,11 @@ class Diagnostics final : public IDiagnostics explicit Diagnostics(std::shared_ptr clInfo); void SetBuildOptions(const nlohmann::json& options); + void SetBuildOptions(const std::string& options); void SetMaxProblemsCount(uint64_t maxNumberOfProblems); void SetOpenCLDevice(uint32_t identifier); - nlohmann::json Get(const Source& source); + std::string GetBuildLog(const Source& source); + nlohmann::json GetDiagnostics(const Source& source); private: nlohmann::json BuildDiagnostics(const std::string& buildLog, const std::string& name); @@ -237,26 +239,27 @@ nlohmann::json Diagnostics::BuildDiagnostics(const std::string& buildLog, const return diagnostics; } -nlohmann::json Diagnostics::Get(const Source& source) +std::string Diagnostics::GetBuildLog(const Source& source) { if (!m_device.has_value()) { throw std::runtime_error("missing OpenCL device"); } - spdlog::get(logger)->trace("Getting diagnostics..."); - std::string buildLog; - std::string srcName; + return BuildSource(source.text); +} +nlohmann::json Diagnostics::GetDiagnostics(const Source& source) +{ + std::string buildLog = GetBuildLog(source); + std::string srcName; if (!source.filePath.empty()) { auto filePath = std::filesystem::path(source.filePath).string(); srcName = std::filesystem::path(filePath).filename().string(); } - buildLog = BuildSource(source.text); - spdlog::get(logger)->trace("BuildLog:\n", buildLog); - + spdlog::get(logger)->trace("BuildLog:\n{}", buildLog); return BuildDiagnostics(buildLog, srcName); } @@ -264,14 +267,9 @@ void Diagnostics::SetBuildOptions(const json& options) { try { - std::string args; - for (auto option : options) - { - args.append(option.get()); - args.append(" "); - } - m_BuildOptions = std::move(args); - spdlog::get(logger)->trace("Set build options, {}", m_BuildOptions); + auto concat = [](const std::string& acc, const json& j) { return acc + j.get() + " "; }; + auto opts = std::accumulate(options.begin(), options.end(), std::string(), concat); + SetBuildOptions(opts); } catch (std::exception& e) { @@ -279,6 +277,12 @@ void Diagnostics::SetBuildOptions(const json& options) } } +void Diagnostics::SetBuildOptions(const std::string& options) +{ + m_BuildOptions = options; + spdlog::get(logger)->trace("Set build options, {}", m_BuildOptions); +} + void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) { spdlog::get(logger)->trace("Set max number of problems: {}", maxNumberOfProblems); diff --git a/src/lsp.cpp b/src/lsp.cpp index 4ad2423..d02048e 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -244,7 +244,7 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con const auto filePath = utils::UriToPath(uri); spdlog::get(logger)->debug("Converted uri '{}' to path '{}'", uri, filePath); - json diags = m_diagnostics->Get({filePath, content}); + json diags = m_diagnostics->GetDiagnostics({filePath, content}); m_outQueue.push( {{"method", "textDocument/publishDiagnostics"}, {"params", diff --git a/src/main.cpp b/src/main.cpp index bf0ea69..28db961 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -2,7 +2,7 @@ // main.cpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/14/21. +// Created by Ilia Shoshin on 7/14/21. // #include @@ -30,6 +30,112 @@ using namespace ocls; namespace { +struct SubCommand +{ + SubCommand(CLI::App& app, std::string name, std::string description) : cmd {app.add_subcommand(name, description)} + {} + + ~SubCommand() = default; + + bool IsParsed() const + { + return cmd->parsed(); + } + +protected: + CLI::App* cmd; +}; + +struct CLInfoSubCommand final : public SubCommand +{ + CLInfoSubCommand(CLI::App& app) : SubCommand(app, "clinfo", "Show information about available OpenCL devices") + { + cmd->add_flag("-p,--pretty-print", prettyPrint, "Enable pretty-printing"); + } + + int Execute(const std::shared_ptr& clinfo) + { + const auto jsonBody = clinfo->json(); + const int indentation = prettyPrint ? 4 : -1; + std::cout << jsonBody.dump(indentation) << std::endl; + return EXIT_SUCCESS; + } + +private: + bool prettyPrint = false; +}; + +struct DiagnosticsSubCommand final : public SubCommand +{ + DiagnosticsSubCommand(CLI::App& app) : SubCommand(app, "diagnostics", "Provides an OpenCL kernel diagnostics") + { + cmd->add_flag("-j,--json", json, "Print diagnostics in JSON format"); + cmd->add_option("-k,--kernel", kernel, "Path to a kernel file")->required(true); + cmd->add_option("-b,--build-options", buildOptions, "Options to be utilized when building the program.") + ->capture_default_str(); + cmd->add_option( + "-d,--device-id", + deviceID, + "Device ID or 0 (automatic selection) of the OpenCL device to be used for diagnostics.") + ->capture_default_str(); + cmd->add_option("--error-limit", maxNumberOfProblems, "The maximum number of errors parsed by the compiler.") + ->capture_default_str(); + } + + int Execute(const std::shared_ptr& diagnostics) + { + auto content = utils::ReadFileContent(kernel); + if (!content.has_value()) + { + return EXIT_FAILURE; + } + + try + { + if (!buildOptions.empty()) + { + diagnostics->SetBuildOptions(buildOptions); + } + + if (deviceID > 0) + { + diagnostics->SetOpenCLDevice(deviceID); + } + + if (maxNumberOfProblems != INT8_MAX) + { + diagnostics->SetMaxProblemsCount(maxNumberOfProblems); + } + + Source source {kernel, *content}; + if (json) + { + auto output = diagnostics->GetDiagnostics(source); + std::cout << output.dump(4) << std::endl; + } + else + { + auto output = diagnostics->GetBuildLog(source); + std::cout << output << std::endl; + } + } + catch (std::exception& err) + { + std::cerr << "Failed to get diagnostics: " << err.what() << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; + } + +private: + std::string kernel; + std::string buildOptions; + uint32_t deviceID = 0; + uint64_t maxNumberOfProblems = INT8_MAX; + bool json = false; +}; + std::shared_ptr server; static void SignalHandler(int) @@ -77,10 +183,12 @@ inline void SetupBinaryStreamMode() { #if defined(WIN32) // to handle CRLF - if (_setmode(_fileno(stdin), _O_BINARY) == -1) { + if (_setmode(_fileno(stdin), _O_BINARY) == -1) + { spdlog::error("Cannot set stdin mode to _O_BINARY"); } - if (_setmode(_fileno(stdout), _O_BINARY) == -1) { + if (_setmode(_fileno(stdout), _O_BINARY) == -1) + { spdlog::error("Cannot set stdout mode to _O_BINARY"); } #endif @@ -94,11 +202,10 @@ int main(int argc, char* argv[]) std::string optLogFile = "opencl-language-server.log"; spdlog::level::level_enum optLogLevel = spdlog::level::trace; - CLI::App app { - "OpenCL Language Server\n" - "The language server communicates with a client using JSON-RPC protocol.\n" - "You can stop the server by sending an interrupt signal followed by any character sent to standard input.\n" - }; + CLI::App app {"OpenCL Language Server\n" + "The language server communicates with a client using JSON-RPC protocol.\n" + "Stop the server by sending an interrupt signal followed by any character sent to standard input.\n" + "Optionally, you can use subcommands to access functionality without starting the server.\n"}; app.add_flag("-e,--enable-file-logging", flagLogTofile, "Enable file logging"); app.add_option("-f,--log-file", optLogFile, "Path to log file")->required(false)->capture_default_str(); app.add_option("-l,--log-level", optLogLevel, "Log level") @@ -115,33 +222,31 @@ int main(int argc, char* argv[]) "-v,--version", []() { std::cout << ocls::version << std::endl; - exit(0); + exit(EXIT_SUCCESS); }, "Show version"); - - bool flagPrettyPrint = false; - auto clinfoCmd = app.add_subcommand("clinfo", "Show information about available OpenCL devices"); - clinfoCmd->add_flag("-p,--pretty-print", flagPrettyPrint, "Enable pretty-printing"); + CLInfoSubCommand clInfoCmd(app); + DiagnosticsSubCommand diagnosticsCmd(app); CLI11_PARSE(app, argc, argv); - ConfigureLogging(flagLogTofile, optLogFile, optLogLevel); auto clinfo = CreateCLInfo(); - if (*clinfoCmd) + if (clInfoCmd.IsParsed()) { - const auto jsonBody = clinfo->json(); - const int indentation = flagPrettyPrint ? 4 : -1; - std::cout << jsonBody.dump(indentation) << std::endl; - return 0; + return clInfoCmd.Execute(clinfo); } - SetupBinaryStreamMode(); + auto diagnostics = CreateDiagnostics(clinfo); + if (diagnosticsCmd.IsParsed()) + { + return diagnosticsCmd.Execute(diagnostics); + } + SetupBinaryStreamMode(); std::signal(SIGINT, SignalHandler); auto jrpc = CreateJsonRPC(); - auto diagnostics = CreateDiagnostics(clinfo); server = CreateLSPServer(jrpc, diagnostics); return server->Run(); } diff --git a/src/utils.cpp b/src/utils.cpp index 44190e8..0c9e6b3 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -121,6 +122,22 @@ bool EndsWith(const std::string& str, const std::string& suffix) return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } +std::optional ReadFileContent(std::string_view fileName) +{ + std::string content; + std::ifstream file(fileName); + if (file.is_open()) { + std::stringstream buffer; + buffer << file.rdbuf(); + file.close(); + content = buffer.str(); + } else { + spdlog::error("Unable to open file '{}'", fileName); + return std::nullopt; + } + return content; +} + namespace internal { // Generates a lookup table for the checksums of all 8-bit values. std::array GenerateCRCLookupTable() diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index 09e7c0b..e03af31 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -259,8 +259,8 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_shouldBuildResponse) auto expectedDiagnostics = GetTestDiagnostics(uri); auto expectedResponse = GetTestDiagnosticsResponse(uri); - ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); - EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); + ON_CALL(*mockDiagnostics, GetDiagnostics(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, GetDiagnostics(Source {uri, content})).Times(1); handler->BuildDiagnosticsRespond(uri, content); auto response = handler->GetNextResponse(); @@ -274,8 +274,8 @@ TEST_F(LSPTest, BuildDiagnosticsRespond_withException_shouldReplyWithError) auto [uri, content] = GetTestSource(); Source expectedSource {uri, content}; - ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Throw(std::runtime_error("Exception"))); - EXPECT_CALL(*mockDiagnostics, Get(expectedSource)).Times(1); + ON_CALL(*mockDiagnostics, GetDiagnostics(testing::_)).WillByDefault(::testing::Throw(std::runtime_error("Exception"))); + EXPECT_CALL(*mockDiagnostics, GetDiagnostics(expectedSource)).Times(1); EXPECT_CALL(*mockJsonRPC, WriteError(JRPCErrorCode::InternalError, "Failed to get diagnostics: Exception")) .Times(1); @@ -291,8 +291,8 @@ TEST_F(LSPTest, OnTextOpen_shouldBuildResponse) auto expectedResponse = GetTestDiagnosticsResponse(uri); nlohmann::json request = {{"params", {{"textDocument", {{"uri", uri}, {"text", content}}}}}}; - ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); - EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); + ON_CALL(*mockDiagnostics, GetDiagnostics(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, GetDiagnostics(Source {uri, content})).Times(1); handler->OnTextOpen(request); auto response = handler->GetNextResponse(); @@ -316,8 +316,8 @@ TEST_F(LSPTest, OnTextChanged_shouldBuildResponse) }}, {"contentChanges", {{{"text", content}}}}}}}; - ON_CALL(*mockDiagnostics, Get(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); - EXPECT_CALL(*mockDiagnostics, Get(Source {uri, content})).Times(1); + ON_CALL(*mockDiagnostics, GetDiagnostics(testing::_)).WillByDefault(::testing::Return(expectedDiagnostics)); + EXPECT_CALL(*mockDiagnostics, GetDiagnostics(Source {uri, content})).Times(1); handler->OnTextChanged(request); auto response = handler->GetNextResponse(); diff --git a/tests/mocks/diagnostics-mock.hpp b/tests/mocks/diagnostics-mock.hpp index a85c4cb..0702f2e 100644 --- a/tests/mocks/diagnostics-mock.hpp +++ b/tests/mocks/diagnostics-mock.hpp @@ -14,9 +14,13 @@ class DiagnosticsMock : public ocls::IDiagnostics public: MOCK_METHOD(void, SetBuildOptions, (const nlohmann::json&), (override)); + MOCK_METHOD(void, SetBuildOptions, (const std::string&), (override)); + MOCK_METHOD(void, SetMaxProblemsCount, (uint64_t), (override)); MOCK_METHOD(void, SetOpenCLDevice, (uint32_t), (override)); - MOCK_METHOD(nlohmann::json, Get, (const ocls::Source&), (override)); + MOCK_METHOD(std::string, GetBuildLog, (const ocls::Source&), (override)); + + MOCK_METHOD(nlohmann::json, GetDiagnostics, (const ocls::Source&), (override)); }; From 9679a06bf9a16abb42781f09e3ca9dbfa4360129 Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 11 Aug 2023 23:25:25 +0300 Subject: [PATCH 19/32] [#19] Add uriparser dependency to convert URI to file path --- CMakeLists.txt | 3 ++- conanfile.py | 1 + include/utils.hpp | 2 +- src/lsp.cpp | 2 +- src/utils.cpp | 64 ++++++++++++++++++-------------------------- tests/CMakeLists.txt | 2 +- 6 files changed, 32 insertions(+), 42 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b57aa2c..88696e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ find_package(OpenCLHeadersCpp REQUIRED) find_package(spdlog REQUIRED) find_package(nlohmann_json REQUIRED) find_package(CLI11 REQUIRED) +find_package(uriparser REQUIRED) if(ENABLE_TESTING) find_package(GTest REQUIRED) @@ -75,7 +76,7 @@ list(TRANSFORM sources PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/src/") configure_file(${CMAKE_CURRENT_SOURCE_DIR}/include/version.hpp.in version.hpp) source_group("include" FILES ${headers}) source_group("src" FILES ${sources}) -set(libs nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp CLI11::CLI11) +set(libs nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp CLI11::CLI11 uriparser::uriparser) if(LINUX) set(libs ${libs} stdc++fs OpenCL::OpenCL) elseif(APPLE) diff --git a/conanfile.py b/conanfile.py index af04d32..d529140 100644 --- a/conanfile.py +++ b/conanfile.py @@ -24,6 +24,7 @@ class OpenCLLanguageServerConanfile(ConanFile): "nlohmann_json/[^3.11.2]", "opencl-clhpp-headers/2022.09.30", "spdlog/[^1.11.0]", + "uriparser/[^0.9.7]" ) exports_sources = ( "include/**", diff --git a/include/utils.hpp b/include/utils.hpp index 330f1d7..d4792a9 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -37,7 +37,7 @@ std::shared_ptr CreateDefaultExitHandler(); void Trim(std::string& s); std::vector SplitString(const std::string& str, const std::string& pattern); -std::string UriToPath(const std::string& uri); +std::string UriToFilePath(const std::string& uri); bool EndsWith(const std::string& str, const std::string& suffix); std::optional ReadFileContent(std::string_view fileName); diff --git a/src/lsp.cpp b/src/lsp.cpp index d02048e..9442757 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -241,7 +241,7 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con { try { - const auto filePath = utils::UriToPath(uri); + const auto filePath = utils::UriToFilePath(uri); spdlog::get(logger)->debug("Converted uri '{}' to path '{}'", uri, filePath); json diags = m_diagnostics->GetDiagnostics({filePath, content}); diff --git a/src/utils.cpp b/src/utils.cpp index 0c9e6b3..2d6f4e0 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -16,6 +16,8 @@ #include #include +#include + namespace ocls::utils { class DefaultGenerator final : public IGenerator @@ -78,43 +80,26 @@ std::vector SplitString(const std::string& str, const std::string& return result; } -// Limited file uri -> path converter -std::string UriToPath(const std::string& uri) +std::string UriToFilePath(const std::string& uri) { - try - { - std::string str = uri; - auto pos = str.find("file://"); - if (pos != std::string::npos) - str.replace(pos, 7, ""); - do - { - pos = str.find("%3A"); - if (pos != std::string::npos) - str.replace(pos, 3, ":"); - } while (pos != std::string::npos); - - do - { - pos = str.find("%20"); - if (pos != std::string::npos) - str.replace(pos, 3, " "); - } while (pos != std::string::npos); - + const size_t bytesNeeded = 8 + 3 * uri.length() + 1; + char* fileName = (char*)malloc(bytesNeeded * sizeof(char)); #if defined(WIN32) - // remove first / - if (str.rfind("/", 0) == 0) - { - str.replace(0, 1, ""); - } -#endif - return str; + if (uriUriStringToWindowsFilenameA(uri.c_str(), fileName) != URI_SUCCESS) + { + free(fileName); + throw std::runtime_error("Failed to convert URI to Windows filename."); } - catch (std::exception& e) +#else + if (uriUriStringToUnixFilenameA(uri.c_str(), fileName) != URI_SUCCESS) { - spdlog::error("Failed to convert file uri to path, {}", e.what()); + free(fileName); + throw std::runtime_error("Failed to convert URI to Unix filename."); } - return uri; +#endif + std::string result(fileName); + free(fileName); + return result; } bool EndsWith(const std::string& str, const std::string& suffix) @@ -126,12 +111,15 @@ std::optional ReadFileContent(std::string_view fileName) { std::string content; std::ifstream file(fileName); - if (file.is_open()) { - std::stringstream buffer; - buffer << file.rdbuf(); - file.close(); - content = buffer.str(); - } else { + if (file.is_open()) + { + std::stringstream buffer; + buffer << file.rdbuf(); + file.close(); + content = buffer.str(); + } + else + { spdlog::error("Unable to open file '{}'", fileName); return std::nullopt; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 74f8cba..b0407cc 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -10,7 +10,7 @@ set(test_sources lsp-event-handler-tests.cpp main.cpp ) -set(libs GTest::gmock nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp) +set(libs GTest::gmock nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp uriparser::uriparser) if(LINUX) set(libs ${libs} stdc++fs OpenCL::OpenCL) elseif(APPLE) From 9436274751a065d5dbf71edc025bad18d63d6a89 Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 11 Aug 2023 23:37:56 +0300 Subject: [PATCH 20/32] Update test_package to support 'clinfo' subcommand --- test_package/conanfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_package/conanfile.py b/test_package/conanfile.py index c71246d..f4eb3b6 100644 --- a/test_package/conanfile.py +++ b/test_package/conanfile.py @@ -27,6 +27,6 @@ def test(self): assert len(output.getvalue()) > 0 output = StringIO() - self.run(f"{opencl_ls} --clinfo", output) + self.run(f"{opencl_ls} clinfo", output) clinfo = json.loads(output.getvalue()) print(json.dumps(clinfo, indent=2)) From c5d3e2613777795bd4f3fd7d056b15cce65a6756 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sat, 12 Aug 2023 23:03:06 +0300 Subject: [PATCH 21/32] Refactor Diagnostics class --- include/clinfo.hpp | 6 +- include/diagnostics.hpp | 2 +- src/clinfo.cpp | 84 ++++++++++++++++++++++---- src/diagnostics.cpp | 130 +++++++++++++++++++--------------------- 4 files changed, 141 insertions(+), 81 deletions(-) diff --git a/include/clinfo.hpp b/include/clinfo.hpp index 07415c8..67bb2d9 100644 --- a/include/clinfo.hpp +++ b/include/clinfo.hpp @@ -2,7 +2,7 @@ // clinfo.hpp // opencl-language-server // -// Created by is on 5.2.2023. +// Created by Ilia Shoshin on 5.2.2023. // #pragma once @@ -21,9 +21,13 @@ struct ICLInfo virtual nlohmann::json json() = 0; + virtual std::vector GetDevices() = 0; + virtual uint32_t GetDeviceID(const cl::Device& device) = 0; virtual std::string GetDeviceDescription(const cl::Device& device) = 0; + + virtual size_t GetDevicePowerIndex(const cl::Device& device) = 0; }; std::shared_ptr CreateCLInfo(); diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index 17fcb89..f564d27 100644 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -2,7 +2,7 @@ // diagnostics.hpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #pragma once diff --git a/src/clinfo.cpp b/src/clinfo.cpp index 98293e1..c1f3922 100644 --- a/src/clinfo.cpp +++ b/src/clinfo.cpp @@ -2,7 +2,7 @@ // clinfo.cpp // opencl-language-server // -// Created by is on 5.2.2023. +// Created by Ilia Shoshin on 5.2.2023. // #include "clinfo.hpp" @@ -475,6 +475,45 @@ class CLInfo final : public ICLInfo return nlohmann::json {{"PLATFORMS", jsonPlatforms}}; } + std::vector GetPlatforms() + { + std::vector platforms; + try + { + cl::Platform::get(&platforms); + spdlog::get(logger)->info("Found OpenCL platforms: {}", platforms.size()); + } + catch (cl::Error& err) + { + spdlog::get(logger)->error("No OpenCL platforms were found, {}", err.what()); + } + return platforms; + } + + std::vector GetDevices() + { + std::vector devices; + const auto platforms = GetPlatforms(); + for (const auto& platform : platforms) + { + try + { + std::vector platformDevices; + platform.getDevices(CL_DEVICE_TYPE_ALL, &platformDevices); + spdlog::get(logger)->info("Found OpenCL devices: {}", platformDevices.size()); + devices.insert( + devices.end(), + std::make_move_iterator(platformDevices.begin()), + std::make_move_iterator(platformDevices.end())); + } + catch (cl::Error& err) + { + spdlog::get(logger)->error("No OpenCL devices were found, {}", err.what()); + } + } + return devices; + } + uint32_t GetDeviceID(const cl::Device& device) { return CalculateDeviceID(device); @@ -482,16 +521,39 @@ class CLInfo final : public ICLInfo std::string GetDeviceDescription(const cl::Device& device) { - auto name = device.getInfo(); - auto type = device.getInfo(); - auto version = device.getInfo(); - auto vendor = device.getInfo(); - auto vendorID = device.getInfo(); - auto driverVersion = device.getInfo(); - auto description = "name: " + std::move(name) + "; " + "type: " + std::to_string(type) + "; " + - "version: " + std::move(version) + "; " + "vendor: " + std::move(vendor) + "; " + - "vendorID: " + std::to_string(vendorID) + "; " + "driverVersion: " + std::move(driverVersion); - return description; + try + { + auto name = device.getInfo(); + auto type = device.getInfo(); + auto version = device.getInfo(); + auto vendor = device.getInfo(); + auto vendorID = device.getInfo(); + auto driverVersion = device.getInfo(); + auto description = "name: " + std::move(name) + "; " + "type: " + std::to_string(type) + "; " + + "version: " + std::move(version) + "; " + "vendor: " + std::move(vendor) + "; " + + "vendorID: " + std::to_string(vendorID) + "; " + "driverVersion: " + std::move(driverVersion); + return description; + } + catch (cl::Error& err) + { + spdlog::get(logger)->error("Failed to get description for the selected device, {}", err.what()); + } + return "unknown"; + } + + size_t GetDevicePowerIndex(const cl::Device& device) + { + try + { + const size_t maxComputeUnits = device.getInfo(); + const size_t maxClockFrequency = device.getInfo(); + return maxComputeUnits * maxClockFrequency; + } + catch (const cl::Error& err) + { + spdlog::get(logger)->error("Failed to get power index for the device, {}", err.what()); + } + return 0; } }; diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index e20b862..51c6ba0 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -2,7 +2,7 @@ // diagnostics.cpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/16/21. +// Created by Ilia Shoshin on 7/16/21. // #include "diagnostics.hpp" @@ -46,13 +46,6 @@ std::tuple ParseOutput(const std::sm return std::make_tuple(std::move(source), line, col, severity, std::move(message)); } -size_t GetDevicePowerIndex(const cl::Device& device) -{ - const size_t maxComputeUnits = device.getInfo(); - const size_t maxClockFrequency = device.getInfo(); - return maxComputeUnits * maxClockFrequency; -} - } // namespace namespace ocls { @@ -70,6 +63,8 @@ class Diagnostics final : public IDiagnostics nlohmann::json GetDiagnostics(const Source& source); private: + std::optional SelectOpenCLDevice(const std::vector& devices, uint32_t identifier); + std::optional SelectOpenCLDeviceByPowerIndex(const std::vector& devices); nlohmann::json BuildDiagnostics(const std::string& buildLog, const std::string& name); std::string BuildSource(const std::string& source) const; @@ -86,78 +81,76 @@ Diagnostics::Diagnostics(std::shared_ptr clInfo) : m_clInfo {std::move( SetOpenCLDevice(0); } -void Diagnostics::SetOpenCLDevice(uint32_t identifier) +std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std::vector& devices) { - spdlog::get(logger)->trace("Selecting OpenCL platform..."); - std::vector platforms; - try - { - cl::Platform::get(&platforms); - } - catch (cl::Error& err) - { - spdlog::get(logger)->error("No OpenCL platforms were found, {}", err.what()); - } + auto maxIt = std::max_element(devices.begin(), devices.end(), [this](const cl::Device& a, const cl::Device& b) { + const auto powerIndexA = m_clInfo->GetDevicePowerIndex(a); + const auto powerIndexB = m_clInfo->GetDevicePowerIndex(b); + return powerIndexA < powerIndexB; + }); - spdlog::get(logger)->info("Found OpenCL platforms: {}", platforms.size()); - if (platforms.size() == 0) + if (maxIt == devices.end()) { - return; + return std::nullopt; } - std::string description; + return *maxIt; +} + +std::optional Diagnostics::SelectOpenCLDevice(const std::vector& devices, uint32_t identifier) +{ + auto log = spdlog::get(logger); std::optional selectedDevice; - for (auto& platform : platforms) - { - std::vector devices; + + // Find device by identifier + auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { try { - platform.getDevices(CL_DEVICE_TYPE_ALL, &devices); + return m_clInfo->GetDeviceID(device) == identifier; } - catch (cl::Error& err) + catch (const cl::Error&) { - spdlog::get(logger)->error("No OpenCL devices were found, {}", err.what()); - } - spdlog::get(logger)->info("Found OpenCL devices: {}", devices.size()); - if (devices.size() == 0) - { - continue; + return false; } + }); - size_t maxPowerIndex = 0; - spdlog::get(logger)->trace("Selecting OpenCL device (total: {})...", devices.size()); - for (auto& device : devices) - { - size_t powerIndex = 0; - try - { - description = m_clInfo->GetDeviceDescription(device); - auto deviceID = m_clInfo->GetDeviceID(device); - if (identifier == deviceID) - { - selectedDevice = device; - break; - } - else - { - powerIndex = GetDevicePowerIndex(device); - } - } - catch (cl::Error& err) - { - spdlog::get(logger)->error("Failed to get info for a device, {}", err.what()); - continue; - } - - if (powerIndex > maxPowerIndex) - { - maxPowerIndex = powerIndex; - selectedDevice = device; - } - } + if (it != devices.end()) + { + return *it; + } + + // If device is not found by identifier, then find the device based on power index + auto device = SelectOpenCLDeviceByPowerIndex(devices); + if (device && (!m_device || m_clInfo->GetDevicePowerIndex(*device) > m_clInfo->GetDevicePowerIndex(*m_device))) + { + selectedDevice = device; + } + + return selectedDevice; +} + +void Diagnostics::SetOpenCLDevice(uint32_t identifier) +{ + auto log = spdlog::get(logger); + log->trace("Selecting OpenCL device..."); + + const auto devices = m_clInfo->GetDevices(); + + if (devices.size() == 0) + { + return; + } + + m_device = SelectOpenCLDevice(devices, identifier); + + if (!m_device) + { + log->warn("No suitable OpenCL device was found."); + return; } - m_device = selectedDevice; - spdlog::get(logger)->info("Selected OpenCL device: {}", description); + + auto description = m_clInfo->GetDeviceDescription(*m_device); + log->info("Selected OpenCL device: {}", description); } std::string Diagnostics::BuildSource(const std::string& source) const @@ -180,7 +173,8 @@ std::string Diagnostics::BuildSource(const std::string& source) const { if (err.err() != CL_BUILD_PROGRAM_FAILURE) { - spdlog::get(logger)->error("Failed to build program, error, {}", err.what()); + spdlog::get(logger)->error("Failed to build program: {} ({})", err.what(), err.err()); + throw err; } } From 68ad1226ef4194a5b9ee270d1290071cf939a2c7 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 13 Aug 2023 22:55:25 +0300 Subject: [PATCH 22/32] Create DiagnosticsParser class & tests --- include/diagnostics.hpp | 14 +++ src/diagnostics.cpp | 153 ++++++++++++----------- tests/CMakeLists.txt | 26 +++- tests/diagnostics-parser-tests.cpp | 193 +++++++++++++++++++++++++++++ 4 files changed, 311 insertions(+), 75 deletions(-) create mode 100644 tests/diagnostics-parser-tests.cpp diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index f564d27..3a39d9e 100644 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -12,6 +12,8 @@ #include #include #include +#include +#include namespace ocls { @@ -26,6 +28,15 @@ struct Source } }; +struct IDiagnosticsParser +{ + virtual ~IDiagnosticsParser() = default; + + virtual std::tuple ParseMatch(const std::smatch& matches) = 0; + virtual nlohmann::json ParseDiagnostics( + const std::string& buildLog, const std::string& name, uint64_t problemsLimit) = 0; +}; + struct IDiagnostics { virtual ~IDiagnostics() = default; @@ -39,6 +50,9 @@ struct IDiagnostics virtual nlohmann::json GetDiagnostics(const Source& source) = 0; }; +std::shared_ptr CreateDiagnosticsParser(); std::shared_ptr CreateDiagnostics(std::shared_ptr clInfo); +std::shared_ptr CreateDiagnostics( + std::shared_ptr clInfo, std::shared_ptr parser); } // namespace ocls diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index 51c6ba0..a6e7983 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -13,10 +13,9 @@ #include #include #include -#include #include #include // std::runtime_error, std::invalid_argument -#include +#include using namespace nlohmann; @@ -24,36 +23,81 @@ namespace { constexpr char logger[] = "diagnostics"; -int ParseSeverity(const std::string& severity) +} // namespace + +namespace ocls { + +// - DiagnosticsParser + +class DiagnosticsParser final : public IDiagnosticsParser { - if (severity == "error") - return 1; - else if (severity == "warning") - return 2; - else + std::regex m_regex {"^(.*):(\\d+):(\\d+): ((fatal )?error|warning|Scholar): (.*)$"}; + +public: + int ParseSeverity(const std::string& severity) + { + if (severity == "warning") + { + return 2; + } + else if (utils::EndsWith(severity, "error")) + { + return 1; + } return -1; -} + } -// :13:5: warning: no previous prototype for function 'getChannel' -std::tuple ParseOutput(const std::smatch& matches) -{ - std::string source = matches[1]; - const long line = std::stoi(matches[2]) - 1; // LSP assumes 0-indexed lines - const long col = std::stoi(matches[3]); - const int severity = ParseSeverity(matches[4]); - // matches[5] - 'fatal' - std::string message = matches[6]; - return std::make_tuple(std::move(source), line, col, severity, std::move(message)); -} + // example input: :13:5: warning: no previous prototype for function 'getChannel' + std::tuple ParseMatch(const std::smatch& matches) + { + std::string source = matches[1]; + const long line = std::stoi(matches[2]) - 1; // LSP assumes 0-indexed lines + const long col = std::stoi(matches[3]); + const int severity = ParseSeverity(matches[4]); + // matches[5] - 'fatal ' + std::string message = matches[6]; + return std::make_tuple(std::move(source), line, col, severity, std::move(message)); + } -} // namespace + nlohmann::json CreateDiagnostic(const std::smatch& matches, const std::string& name) + { + auto [source, line, col, severity, message] = ParseMatch(matches); + return { + {"source", name.empty() ? source : name}, + {"range", {{"start", {{"line", line}, {"character", col}}}, {"end", {{"line", line}, {"character", col}}}}}, + {"severity", severity}, + {"message", message}}; + } -namespace ocls { + nlohmann::json ParseDiagnostics(const std::string& buildLog, const std::string& name, uint64_t problemsLimit) + { + nlohmann::json diagnostics; + std::istringstream stream(buildLog); + std::string errLine; + uint64_t count = 0; + std::smatch matches; + while (std::getline(stream, errLine)) + { + if (std::regex_search(errLine, matches, m_regex) && matches.size() == 7) + { + if (count++ >= problemsLimit) + { + spdlog::get(logger)->info("Maximum number of problems reached, other problems will be skipped"); + break; + } + diagnostics.emplace_back(CreateDiagnostic(matches, name)); + } + } + return diagnostics; + } +}; + +// - Diagnostics class Diagnostics final : public IDiagnostics { public: - explicit Diagnostics(std::shared_ptr clInfo); + Diagnostics(std::shared_ptr clInfo, std::shared_ptr parser); void SetBuildOptions(const nlohmann::json& options); void SetBuildOptions(const std::string& options); @@ -65,18 +109,19 @@ class Diagnostics final : public IDiagnostics private: std::optional SelectOpenCLDevice(const std::vector& devices, uint32_t identifier); std::optional SelectOpenCLDeviceByPowerIndex(const std::vector& devices); - nlohmann::json BuildDiagnostics(const std::string& buildLog, const std::string& name); std::string BuildSource(const std::string& source) const; private: std::shared_ptr m_clInfo; + std::shared_ptr m_parser; std::optional m_device; - std::regex m_regex {"^(.*):(\\d+):(\\d+): ((fatal )?error|warning|Scholar): (.*)$"}; std::string m_BuildOptions; uint64_t m_maxNumberOfProblems = INT8_MAX; }; -Diagnostics::Diagnostics(std::shared_ptr clInfo) : m_clInfo {std::move(clInfo)} +Diagnostics::Diagnostics(std::shared_ptr clInfo, std::shared_ptr parser) + : m_clInfo {std::move(clInfo)} + , m_parser {std::move(parser)} { SetOpenCLDevice(0); } @@ -192,47 +237,6 @@ std::string Diagnostics::BuildSource(const std::string& source) const return build_log; } -nlohmann::json Diagnostics::BuildDiagnostics(const std::string& buildLog, const std::string& name) -{ - std::smatch matches; - auto errorLines = utils::SplitString(buildLog, "\n"); - json diagnostics; - uint64_t count = 0; - for (auto errLine : errorLines) - { - std::regex_search(errLine, matches, m_regex); - if (matches.size() != 7) - continue; - - if (count++ > m_maxNumberOfProblems) - { - spdlog::get(logger)->info("Maximum number of problems reached, other problems will be slipped"); - break; - } - - auto [source, line, col, severity, message] = ParseOutput(matches); - json diagnostic; - json range { - {"start", - { - {"line", line}, - {"character", col}, - }}, - {"end", - { - {"line", line}, - {"character", col}, - }}, - }; - diagnostic["source"] = name.empty() ? source : name; - diagnostic["range"] = range; - diagnostic["severity"] = severity; - diagnostic["message"] = message; - diagnostics.emplace_back(diagnostic); - } - return diagnostics; -} - std::string Diagnostics::GetBuildLog(const Source& source) { if (!m_device.has_value()) @@ -254,7 +258,7 @@ nlohmann::json Diagnostics::GetDiagnostics(const Source& source) } buildLog = BuildSource(source.text); spdlog::get(logger)->trace("BuildLog:\n{}", buildLog); - return BuildDiagnostics(buildLog, srcName); + return m_parser->ParseDiagnostics(buildLog, srcName, m_maxNumberOfProblems); } void Diagnostics::SetBuildOptions(const json& options) @@ -283,9 +287,20 @@ void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) m_maxNumberOfProblems = maxNumberOfProblems; } +std::shared_ptr CreateDiagnosticsParser() +{ + return std::make_shared(); +} + std::shared_ptr CreateDiagnostics(std::shared_ptr clInfo) { - return std::make_shared(std::move(clInfo)); + return std::make_shared(std::move(clInfo), CreateDiagnosticsParser()); +} + +std::shared_ptr CreateDiagnostics( + std::shared_ptr clInfo, std::shared_ptr parser) +{ + return std::make_shared(std::move(clInfo), parser); } } // namespace ocls diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b0407cc..509eac7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,5 +1,6 @@ set(TESTS_PROJECT_NAME ${PROJECT_NAME}-tests) set(sources + diagnostics.cpp jsonrpc.cpp lsp.cpp utils.cpp @@ -7,6 +8,7 @@ set(sources list(TRANSFORM sources PREPEND "${PROJECT_SOURCE_DIR}/src/") set(test_sources jsonrpc-tests.cpp + diagnostics-parser-tests.cpp lsp-event-handler-tests.cpp main.cpp ) @@ -20,6 +22,23 @@ elseif(WIN32) endif() add_executable (${TESTS_PROJECT_NAME} ${sources} ${test_sources}) +target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE + CL_HPP_ENABLE_EXCEPTIONS + CL_HPP_CL_1_2_DEFAULT_BUILD +) +if(APPLE) +target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE + CL_HPP_MINIMUM_OPENCL_VERSION=120 + CL_HPP_TARGET_OPENCL_VERSION=120 +) +else() +target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE + CL_HPP_MINIMUM_OPENCL_VERSION=110 + CL_HPP_TARGET_OPENCL_VERSION=300 +) +endif() + +target_link_libraries (${TESTS_PROJECT_NAME} ${libs}) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${PROJECT_SOURCE_DIR}/include" "${CMAKE_CURRENT_SOURCE_DIR}/mocks" @@ -27,10 +46,5 @@ target_include_directories(${TESTS_PROJECT_NAME} PRIVATE if(APPLE) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${OpenCL_INCLUDE_DIRS}") endif() -target_link_libraries (${TESTS_PROJECT_NAME} ${libs}) -target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE - CL_HPP_CL_1_2_DEFAULT_BUILD - CL_HPP_TARGET_OPENCL_VERSION=120 - CL_HPP_MINIMUM_OPENCL_VERSION=120 -) + gtest_discover_tests(${TESTS_PROJECT_NAME}) \ No newline at end of file diff --git a/tests/diagnostics-parser-tests.cpp b/tests/diagnostics-parser-tests.cpp new file mode 100644 index 0000000..8c83177 --- /dev/null +++ b/tests/diagnostics-parser-tests.cpp @@ -0,0 +1,193 @@ +// +// diagnostics-parser-tests.cpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 13/8/23. +// + +#include "diagnostics.hpp" + +#include +#include + + +using namespace ocls; +using namespace nlohmann; + +class DiagnosticsParserRegexTest + : public ::testing::TestWithParam> +{}; + +TEST_P(DiagnosticsParserRegexTest, CheckRegexParsing) +{ + auto [input, expectedSource, expectedLine, expectedCol, expectedSeverity, expectedMessage] = GetParam(); + std::regex r("^(.*):(\\d+):(\\d+): ((fatal )?error|warning|Scholar): (.*)$"); + std::smatch match; + EXPECT_TRUE(std::regex_search(input, match, r)); + + auto parser = CreateDiagnosticsParser(); + auto [source, line, col, severity, message] = parser->ParseMatch(match); + + EXPECT_EQ(source, expectedSource); + EXPECT_EQ(expectedLine, line); + EXPECT_EQ(expectedCol, col); + EXPECT_EQ(expectedSeverity, severity); + EXPECT_EQ(expectedMessage, message); +} + +INSTANTIATE_TEST_SUITE_P( + DiagnosticsParserTest, + DiagnosticsParserRegexTest, + ::testing::Values( + std::make_tuple( + ":12:5: warning: no previous prototype for function 'getChannel'", + "", + 11, + 5, + 2, + "no previous prototype for function 'getChannel'"), + std::make_tuple( + ":16:27: error: use of undeclared identifier 'r'", + "", + 15, + 27, + 1, + "use of undeclared identifier 'r'"), + std::make_tuple( + ":100:2: fatal error: unexpected end of file", + "", + 99, + 2, + 1, + "unexpected end of file"), + std::make_tuple( + ":5:14: Scholar: reference missing for citation", + "", + 4, + 14, + -1, + "reference missing for citation"))); + +TEST(ParseDiagnosticsTest, NoDiagnosticMessages) +{ + std::string log = "This is a regular log with no diagnostic message."; + auto parser = CreateDiagnosticsParser(); + auto result = parser->ParseDiagnostics(log, "TestName", 10); + EXPECT_TRUE(result.is_null()); +} + +TEST(ParseDiagnosticsTest, SingleDiagnosticMessage) +{ + std::string log = ":12:5: warning: no previous prototype for function 'getChannel'"; + nlohmann::json expectedResult = R"([ + { + "source": "TestName", + "range": { + "start": { + "line": 11, + "character": 5 + }, + "end": { + "line": 11, + "character": 5 + } + }, + "severity": 2, + "message": "no previous prototype for function 'getChannel'" + } + ])"_json; + auto parser = CreateDiagnosticsParser(); + auto result = parser->ParseDiagnostics(log, "TestName", 10); + EXPECT_EQ(result, expectedResult); +} + +TEST(ParseDiagnosticsTest, MultipleDiagnosticMessages) +{ + std::string log = ":12:5: warning: no previous prototype for function 'getChannel'\n" + ":16:27: error: use of undeclared identifier 'r'"; + nlohmann::json expectedResult = R"([ + { + "source": "TestName", + "range": { + "start": { + "line": 11, + "character": 5 + }, + "end": { + "line": 11, + "character": 5 + } + }, + "severity": 2, + "message": "no previous prototype for function 'getChannel'" + }, + { + "source": "TestName", + "range": { + "start": { + "line": 15, + "character": 27 + }, + "end": { + "line": 15, + "character": 27 + } + }, + "severity": 1, + "message": "use of undeclared identifier 'r'" + } + ])"_json; + auto parser = CreateDiagnosticsParser(); + auto result = parser->ParseDiagnostics(log, "TestName", 10); + EXPECT_EQ(result, expectedResult); +} + +TEST(ParseDiagnosticsTest, ExceedProblemsLimit) +{ + std::string log = ":12:5: warning: no previous prototype for function 'getChannel'\n" + ":16:27: error: use of undeclared identifier 'r'\n" + ":25:7: warning: no previous prototype for function 'quadric'\n"; + nlohmann::json expectedResult = R"([ + { + "source": "TestName", + "range": { + "start": { + "line": 11, + "character": 5 + }, + "end": { + "line": 11, + "character": 5 + } + }, + "severity": 2, + "message": "no previous prototype for function 'getChannel'" + }, + { + "source": "TestName", + "range": { + "start": { + "line": 15, + "character": 27 + }, + "end": { + "line": 15, + "character": 27 + } + }, + "severity": 1, + "message": "use of undeclared identifier 'r'" + } + ])"_json; + auto parser = CreateDiagnosticsParser(); + auto result = parser->ParseDiagnostics(log, "TestName", 2); + EXPECT_EQ(result, expectedResult); +} + +TEST(ParseDiagnosticsTest, MalformedDiagnosticMessage) +{ + std::string log = ":5:14: reference missing for citation"; + auto parser = CreateDiagnosticsParser(); + auto result = parser->ParseDiagnostics(log, "TestName", 10); + EXPECT_TRUE(result.is_null()); +} From a102020c301aa296c92bf2f2fd08063058adf8ee Mon Sep 17 00:00:00 2001 From: Galarius Date: Mon, 14 Aug 2023 23:04:30 +0300 Subject: [PATCH 23/32] Add 'diagnostics' tests --- src/diagnostics.cpp | 185 +++++++++++++++++++----------------- tests/CMakeLists.txt | 1 + tests/diagnostics-tests.cpp | 81 ++++++++++++++++ tests/mocks/clinfo-mock.hpp | 24 +++++ 4 files changed, 202 insertions(+), 89 deletions(-) create mode 100644 tests/diagnostics-tests.cpp create mode 100644 tests/mocks/clinfo-mock.hpp diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index a6e7983..5efa4b8 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -126,6 +126,84 @@ Diagnostics::Diagnostics(std::shared_ptr clInfo, std::shared_ptr() + " "; }; + auto opts = std::accumulate(options.begin(), options.end(), std::string(), concat); + SetBuildOptions(opts); + } + catch (std::exception& e) + { + spdlog::get(logger)->error("Failed to parse build options, {}", e.what()); + } +} + +void Diagnostics::SetBuildOptions(const std::string& options) +{ + m_BuildOptions = options; + spdlog::get(logger)->trace("Set build options, {}", m_BuildOptions); +} + +void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) +{ + spdlog::get(logger)->trace("Set max number of problems: {}", maxNumberOfProblems); + m_maxNumberOfProblems = maxNumberOfProblems; +} + +void Diagnostics::SetOpenCLDevice(uint32_t identifier) +{ + auto log = spdlog::get(logger); + log->trace("Selecting OpenCL device..."); + + const auto devices = m_clInfo->GetDevices(); + + if (devices.size() == 0) + { + return; + } + + auto selectedDevice = SelectOpenCLDevice(devices, identifier); + if (!selectedDevice) + { + log->warn("No suitable OpenCL device was found."); + return; + } + + m_device = selectedDevice; + auto description = m_clInfo->GetDeviceDescription(*m_device); + log->info("Selected OpenCL device: {}", description); +} + +std::string Diagnostics::GetBuildLog(const Source& source) +{ + if (!m_device.has_value()) + { + throw std::runtime_error("missing OpenCL device"); + } + spdlog::get(logger)->trace("Getting diagnostics..."); + return BuildSource(source.text); +} + +nlohmann::json Diagnostics::GetDiagnostics(const Source& source) +{ + std::string buildLog = GetBuildLog(source); + std::string srcName; + if (!source.filePath.empty()) + { + auto filePath = std::filesystem::path(source.filePath).string(); + srcName = std::filesystem::path(filePath).filename().string(); + } + buildLog = BuildSource(source.text); + spdlog::get(logger)->trace("BuildLog:\n{}", buildLog); + return m_parser->ParseDiagnostics(buildLog, srcName, m_maxNumberOfProblems); +} + +// - + std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std::vector& devices) { auto maxIt = std::max_element(devices.begin(), devices.end(), [this](const cl::Device& a, const cl::Device& b) { @@ -145,57 +223,34 @@ std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std: std::optional Diagnostics::SelectOpenCLDevice(const std::vector& devices, uint32_t identifier) { auto log = spdlog::get(logger); - std::optional selectedDevice; - // Find device by identifier - auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { - try - { - return m_clInfo->GetDeviceID(device) == identifier; - } - catch (const cl::Error&) + if (identifier > 0) + { + auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { + try + { + return m_clInfo->GetDeviceID(device) == identifier; + } + catch (const cl::Error&) + { + return false; + } + }); + + if (it != devices.end()) { - return false; + return *it; } - }); - - if (it != devices.end()) - { - return *it; } // If device is not found by identifier, then find the device based on power index auto device = SelectOpenCLDeviceByPowerIndex(devices); if (device && (!m_device || m_clInfo->GetDevicePowerIndex(*device) > m_clInfo->GetDevicePowerIndex(*m_device))) { - selectedDevice = device; + return device; } - return selectedDevice; -} - -void Diagnostics::SetOpenCLDevice(uint32_t identifier) -{ - auto log = spdlog::get(logger); - log->trace("Selecting OpenCL device..."); - - const auto devices = m_clInfo->GetDevices(); - - if (devices.size() == 0) - { - return; - } - - m_device = SelectOpenCLDevice(devices, identifier); - - if (!m_device) - { - log->warn("No suitable OpenCL device was found."); - return; - } - - auto description = m_clInfo->GetDeviceDescription(*m_device); - log->info("Selected OpenCL device: {}", description); + return std::nullopt; } std::string Diagnostics::BuildSource(const std::string& source) const @@ -237,55 +292,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const return build_log; } -std::string Diagnostics::GetBuildLog(const Source& source) -{ - if (!m_device.has_value()) - { - throw std::runtime_error("missing OpenCL device"); - } - spdlog::get(logger)->trace("Getting diagnostics..."); - return BuildSource(source.text); -} - -nlohmann::json Diagnostics::GetDiagnostics(const Source& source) -{ - std::string buildLog = GetBuildLog(source); - std::string srcName; - if (!source.filePath.empty()) - { - auto filePath = std::filesystem::path(source.filePath).string(); - srcName = std::filesystem::path(filePath).filename().string(); - } - buildLog = BuildSource(source.text); - spdlog::get(logger)->trace("BuildLog:\n{}", buildLog); - return m_parser->ParseDiagnostics(buildLog, srcName, m_maxNumberOfProblems); -} - -void Diagnostics::SetBuildOptions(const json& options) -{ - try - { - auto concat = [](const std::string& acc, const json& j) { return acc + j.get() + " "; }; - auto opts = std::accumulate(options.begin(), options.end(), std::string(), concat); - SetBuildOptions(opts); - } - catch (std::exception& e) - { - spdlog::get(logger)->error("Failed to parse build options, {}", e.what()); - } -} - -void Diagnostics::SetBuildOptions(const std::string& options) -{ - m_BuildOptions = options; - spdlog::get(logger)->trace("Set build options, {}", m_BuildOptions); -} - -void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) -{ - spdlog::get(logger)->trace("Set max number of problems: {}", maxNumberOfProblems); - m_maxNumberOfProblems = maxNumberOfProblems; -} +// - std::shared_ptr CreateDiagnosticsParser() { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 509eac7..c527f3e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -9,6 +9,7 @@ list(TRANSFORM sources PREPEND "${PROJECT_SOURCE_DIR}/src/") set(test_sources jsonrpc-tests.cpp diagnostics-parser-tests.cpp + diagnostics-tests.cpp lsp-event-handler-tests.cpp main.cpp ) diff --git a/tests/diagnostics-tests.cpp b/tests/diagnostics-tests.cpp new file mode 100644 index 0000000..d9f5fb8 --- /dev/null +++ b/tests/diagnostics-tests.cpp @@ -0,0 +1,81 @@ +// +// diagnostics-tests.cpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 14/8/23. +// + +#include "diagnostics.hpp" +#include "clinfo-mock.hpp" + +#include +#include +#include + + +using namespace ocls; +using namespace testing; + + +TEST(DiagnosticsTest, SelectDeviceBasedOnPowerIndexDuringTheCreation) +{ + auto mockCLInfo = std::make_shared(); + std::vector devices = {cl::Device(), cl::Device()}; + + EXPECT_CALL(*mockCLInfo, GetDevices()).WillOnce(Return(devices)); + EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)).WillOnce(Return(10)).WillOnce(Return(10)); + EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); + + CreateDiagnostics(mockCLInfo); +} + +TEST(DiagnosticsTest, SelectDeviceBasedOnPowerIndex) +{ + auto mockCLInfo = std::make_shared(); + std::vector devices = {cl::Device(), cl::Device()}; + + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); + EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)) + .WillOnce(Return(10)) + .WillOnce(Return(20)) + .WillOnce(Return(10)) + .WillOnce(Return(20)) + .WillOnce(Return(20)) + .WillOnce(Return(20)); + EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); + auto diagnostics = CreateDiagnostics(mockCLInfo); + diagnostics->SetOpenCLDevice(0); +} + +TEST(DiagnosticsTest, SelectDeviceBasedOnExistingIndex) +{ + auto mockCLInfo = std::make_shared(); + std::vector devices = {cl::Device(), cl::Device()}; + + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); + EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)).WillOnce(Return(10)).WillOnce(Return(20)); + EXPECT_CALL(*mockCLInfo, GetDeviceID(_)).WillOnce(Return(3138399603)).WillOnce(Return(2027288592)); + EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).Times(2).WillRepeatedly(Return("")); + auto diagnostics = CreateDiagnostics(mockCLInfo); + diagnostics->SetOpenCLDevice(2027288592); +} + +TEST(DiagnosticsTest, SelectDeviceBasedOnNonExistingIndex) +{ + auto mockCLInfo = std::make_shared(); + std::vector devices = {cl::Device(), cl::Device()}; + + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); + EXPECT_CALL(*mockCLInfo, GetDeviceID(_)).WillOnce(Return(3138399603)).WillOnce(Return(2027288592)); + EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)) + .WillOnce(Return(10)) + .WillOnce(Return(20)) + .WillOnce(Return(10)) + .WillOnce(Return(20)) + .WillOnce(Return(20)) + .WillOnce(Return(20)); + EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); + + auto diagnostics = CreateDiagnostics(mockCLInfo); + diagnostics->SetOpenCLDevice(4527288514); +} diff --git a/tests/mocks/clinfo-mock.hpp b/tests/mocks/clinfo-mock.hpp new file mode 100644 index 0000000..07a0535 --- /dev/null +++ b/tests/mocks/clinfo-mock.hpp @@ -0,0 +1,24 @@ +// +// clinfo-mock.hpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 14/8/23. +// + +#include "clinfo.hpp" + +#include + +class CLInfoMock : public ocls::ICLInfo +{ +public: + MOCK_METHOD(nlohmann::json, json, (), (override)); + + MOCK_METHOD(std::vector, GetDevices, (), (override)); + + MOCK_METHOD(uint32_t, GetDeviceID, (const cl::Device&), (override)); + + MOCK_METHOD(std::string, GetDeviceDescription, (const cl::Device&), (override)); + + MOCK_METHOD(size_t, GetDevicePowerIndex, (const cl::Device&), (override)); +}; From 2fd144dec1121a909a8eeb9b24d9c374d2bd76e7 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 20 Aug 2023 22:17:31 +0300 Subject: [PATCH 24/32] Refactor utils, add tests --- include/utils.hpp | 18 +++++++-- src/utils.cpp | 69 ++++++++++++++++++++++---------- tests/CMakeLists.txt | 1 + tests/utils-tests.cpp | 91 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 155 insertions(+), 24 deletions(-) create mode 100644 tests/utils-tests.cpp diff --git a/include/utils.hpp b/include/utils.hpp index d4792a9..b4c1843 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -2,7 +2,7 @@ // utils.hpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/21/21. +// Created by Ilia Shoshin on 7/21/21. // #pragma once @@ -35,12 +35,24 @@ std::shared_ptr CreateDefaultGenerator(); std::shared_ptr CreateDefaultExitHandler(); -void Trim(std::string& s); +// --- String Helpers --- + std::vector SplitString(const std::string& str, const std::string& pattern); -std::string UriToFilePath(const std::string& uri); + +void Trim(std::string& s); + bool EndsWith(const std::string& str, const std::string& suffix); + +// --- File Helpers --- + +std::string UriToFilePath(const std::string& uri, bool unix); + +std::string UriToFilePath(const std::string& uri); + std::optional ReadFileContent(std::string_view fileName); +// --- CRC32 --- + namespace internal { // Generates a lookup table for the checksums of all 8-bit values. std::array GenerateCRCLookupTable(); diff --git a/src/utils.cpp b/src/utils.cpp index 2d6f4e0..b6e6e27 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -2,7 +2,7 @@ // utils.cpp // opencl-language-server // -// Created by Ilya Shoshin (Galarius) on 7/21/21. +// Created by Ilia Shoshin on 7/21/21. // #include "utils.hpp" @@ -20,6 +20,8 @@ namespace ocls::utils { +// --- DefaultGenerator --- + class DefaultGenerator final : public IGenerator { public: @@ -47,6 +49,8 @@ std::shared_ptr CreateDefaultGenerator() return std::make_shared(); } +// --- DefaultExitHandler --- + class DefaultExitHandler final : public IExitHandler { public: @@ -63,48 +67,69 @@ std::shared_ptr CreateDefaultExitHandler() return std::make_shared(); } - -void Trim(std::string& s) -{ - s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); - s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); -} +// --- String Helpers --- std::vector SplitString(const std::string& str, const std::string& pattern) { + if (pattern.empty()) + { + return {str}; + } std::vector result; const std::regex re(pattern); std::sregex_token_iterator iter(str.begin(), str.end(), re, -1); for (std::sregex_token_iterator end; iter != end; ++iter) + { result.push_back(iter->str()); + } return result; } -std::string UriToFilePath(const std::string& uri) +void Trim(std::string& s) +{ + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); +} + +bool EndsWith(const std::string& str, const std::string& suffix) +{ + return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +} + +// --- File Helpers --- + +std::string UriToFilePath(const std::string& uri, bool unix) { const size_t bytesNeeded = 8 + 3 * uri.length() + 1; char* fileName = (char*)malloc(bytesNeeded * sizeof(char)); -#if defined(WIN32) - if (uriUriStringToWindowsFilenameA(uri.c_str(), fileName) != URI_SUCCESS) + if (unix) { - free(fileName); - throw std::runtime_error("Failed to convert URI to Windows filename."); + if (uriUriStringToUnixFilenameA(uri.c_str(), fileName) != URI_SUCCESS) + { + free(fileName); + throw std::runtime_error("Failed to convert URI to Unix filename."); + } } -#else - if (uriUriStringToUnixFilenameA(uri.c_str(), fileName) != URI_SUCCESS) + else { - free(fileName); - throw std::runtime_error("Failed to convert URI to Unix filename."); + if (uriUriStringToWindowsFilenameA(uri.c_str(), fileName) != URI_SUCCESS) + { + free(fileName); + throw std::runtime_error("Failed to convert URI to Windows filename."); + } } -#endif std::string result(fileName); free(fileName); return result; } -bool EndsWith(const std::string& str, const std::string& suffix) +std::string UriToFilePath(const std::string& uri) { - return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); +#if defined(WIN32) + return UriToFilePath(uri, false); +#else + return UriToFilePath(uri, true); +#endif } std::optional ReadFileContent(std::string_view fileName) @@ -126,6 +151,8 @@ std::optional ReadFileContent(std::string_view fileName) return content; } +// --- CRC32 --- + namespace internal { // Generates a lookup table for the checksums of all 8-bit values. std::array GenerateCRCLookupTable() @@ -139,10 +166,10 @@ std::array GenerateCRCLookupTable() std::uint_fast32_t operator()() noexcept { auto checksum = static_cast(n++); - for (auto i = 0; i < 8; ++i) + { checksum = (checksum >> 1) ^ ((checksum & 0x1u) ? reversed_polynomial : 0); - + } return checksum; } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c527f3e..cff36c1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,7 @@ set(test_sources diagnostics-parser-tests.cpp diagnostics-tests.cpp lsp-event-handler-tests.cpp + utils-tests.cpp main.cpp ) set(libs GTest::gmock nlohmann_json::nlohmann_json spdlog::spdlog OpenCL::HeadersCpp uriparser::uriparser) diff --git a/tests/utils-tests.cpp b/tests/utils-tests.cpp new file mode 100644 index 0000000..2de0c9e --- /dev/null +++ b/tests/utils-tests.cpp @@ -0,0 +1,91 @@ +// +// utils-tests.cpp +// opencl-language-server-tests +// +// Created by Ilia Shoshin on 20/8/23. +// + +#include "utils.hpp" + +#include + + +using namespace ocls; + + +// --- SplitString --- + + +TEST(UtilsTest, BasicSplit) +{ + std::string str = "apple,banana,orange"; + auto result = utils::SplitString(str, ","); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], "apple"); + EXPECT_EQ(result[1], "banana"); + EXPECT_EQ(result[2], "orange"); +} + +TEST(UtilsTest, NoSplitPattern) +{ + std::string str = "applebananaorange"; + auto result = utils::SplitString(str, ","); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "applebananaorange"); +} + +TEST(UtilsTest, EmptyString) +{ + std::string str = ""; + auto result = utils::SplitString(str, ","); + EXPECT_EQ(result.size(), 1); + EXPECT_TRUE(result[0].empty()); +} + +TEST(UtilsTest, EmptyPattern) +{ + std::string str = "applebananaorange"; + auto result = utils::SplitString(str, ""); + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "applebananaorange"); +} + +TEST(UtilsTest, SplitAtEveryCharacter) +{ + std::string str = "apple"; + auto result = utils::SplitString(str, "p"); + EXPECT_EQ(result.size(), 3); + EXPECT_EQ(result[0], "a"); + EXPECT_EQ(result[1], ""); + EXPECT_EQ(result[2], "le"); +} + +// --- UriToFilePath --- + +TEST(UriToFilePathTest, BasicUri) +{ + std::string uri = "file:///C:/folder/file.txt"; + auto result = utils::UriToFilePath(uri, false); + EXPECT_EQ(result, "C:\\folder\\file.txt"); +} + +TEST(UriToFilePathTest, UriWithSpaces) +{ + std::string uri = "file:///C:/some%20folder/file%20name.txt"; + auto result = utils::UriToFilePath(uri, false); + EXPECT_EQ(result, "C:\\some folder\\file name.txt"); +} + +TEST(UriToFilePathTest, UnixPath) +{ + std::string uri = "file:///home/user/folder/file.txt"; + auto result = utils::UriToFilePath(uri, true); + EXPECT_EQ(result, "/home/user/folder/file.txt"); +} + +TEST(UriToFilePathTest, EmptyUri) +{ + std::string uri = ""; + auto result = utils::UriToFilePath(uri, true); + EXPECT_EQ(result, ""); // Assuming it returns empty string +} From 4f23ba3fd2b33fc60a2ba1f6b9bbd7eea084a654 Mon Sep 17 00:00:00 2001 From: Galarius Date: Sun, 20 Aug 2023 22:58:43 +0300 Subject: [PATCH 25/32] Refactor logging --- CMakeLists.txt | 2 ++ include/log.hpp | 25 ++++++++++++++ include/utils.hpp | 5 +++ src/clinfo.cpp | 32 +++++++++--------- src/diagnostics.cpp | 31 ++++++++---------- src/jsonrpc.cpp | 77 +++++++++++++++++++++----------------------- src/log.cpp | 67 ++++++++++++++++++++++++++++++++++++++ src/lsp.cpp | 54 +++++++++++++++---------------- src/main.cpp | 45 ++++++-------------------- tests/CMakeLists.txt | 1 + tests/main.cpp | 16 ++------- 11 files changed, 206 insertions(+), 149 deletions(-) create mode 100644 include/log.hpp create mode 100644 src/log.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 88696e6..c74eda8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ set(headers clinfo.hpp diagnostics.hpp jsonrpc.hpp + log.hpp lsp.hpp utils.hpp ) @@ -67,6 +68,7 @@ set(sources clinfo.cpp diagnostics.cpp jsonrpc.cpp + log.cpp lsp.cpp main.cpp utils.cpp diff --git a/include/log.hpp b/include/log.hpp new file mode 100644 index 0000000..822487c --- /dev/null +++ b/include/log.hpp @@ -0,0 +1,25 @@ +// +// logs.hpp +// opencl-language-server +// +// Created by Ilia Shoshin on 20/08/23. +// + +#include +#include + +namespace ocls { + +struct LogName +{ + static std::string main; + static std::string clinfo; + static std::string diagnostics; + static std::string jrpc; + static std::string lsp; +}; + +extern void ConfigureFileLogging(const std::string& filename, spdlog::level::level_enum level); +extern void ConfigureNullLogging(); + +} // namespace ocls diff --git a/include/utils.hpp b/include/utils.hpp index b4c1843..37f7c4a 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -43,6 +43,11 @@ void Trim(std::string& s); bool EndsWith(const std::string& str, const std::string& suffix); +inline std::string FormatBool(bool flag) +{ + return flag ? "yes" : "no"; +} + // --- File Helpers --- std::string UriToFilePath(const std::string& uri, bool unix); diff --git a/src/clinfo.cpp b/src/clinfo.cpp index c1f3922..2f5b639 100644 --- a/src/clinfo.cpp +++ b/src/clinfo.cpp @@ -6,10 +6,10 @@ // #include "clinfo.hpp" +#include "log.hpp" #include "utils.hpp" #include -#include #include using namespace nlohmann; @@ -18,7 +18,7 @@ using ocls::ICLInfo; namespace { -constexpr const char logger[] = "clinfo"; +auto logger() { return spdlog::get(ocls::LogName::clinfo); } const std::unordered_map booleanChoices { {CL_TRUE, "CL_TRUE"}, @@ -350,7 +350,7 @@ json::object_t GetDeviceJSONInfo(const cl::Device& device) } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to get info for the device, {}", err.what()); + logger()->error("Failed to get info for the device, {}", err.what()); continue; } } @@ -373,7 +373,7 @@ uint32_t CalculateDeviceID(const cl::Device& device) } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to calculate device uuid, {}", err.what()); + logger()->error("Failed to calculate device uuid, {}", err.what()); } return 0; } @@ -403,7 +403,7 @@ uint32_t CalculatePlatformID(const cl::Platform& platform) } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to calculate platform uuid, {}", err.what()); + logger()->error("Failed to calculate platform uuid, {}", err.what()); } return 0; } @@ -421,7 +421,7 @@ json GetPlatformJSONInfo(const cl::Platform& platform) } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to get info for a platform, {}", err.what()); + logger()->error("Failed to get info for a platform, {}", err.what()); } } @@ -438,7 +438,7 @@ json GetPlatformJSONInfo(const cl::Platform& platform) } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to get devices for a platform, {}", err.what()); + logger()->error("Failed to get devices for a platform, {}", err.what()); } return info; @@ -449,7 +449,7 @@ class CLInfo final : public ICLInfo public: nlohmann::json json() { - spdlog::get(logger)->trace("Searching for OpenCL platforms..."); + logger()->trace("Searching for OpenCL platforms..."); std::vector platforms; try { @@ -457,10 +457,10 @@ class CLInfo final : public ICLInfo } catch (const cl::Error& err) { - spdlog::get(logger)->error("No OpenCL platforms were found ({})", err.what()); + logger()->error("No OpenCL platforms were found ({})", err.what()); } - spdlog::get(logger)->info("Found OpenCL platforms, {}", platforms.size()); + logger()->info("Found OpenCL platforms, {}", platforms.size()); if (platforms.size() == 0) { return {}; @@ -481,11 +481,11 @@ class CLInfo final : public ICLInfo try { cl::Platform::get(&platforms); - spdlog::get(logger)->info("Found OpenCL platforms: {}", platforms.size()); + logger()->info("Found OpenCL platforms: {}", platforms.size()); } catch (cl::Error& err) { - spdlog::get(logger)->error("No OpenCL platforms were found, {}", err.what()); + logger()->error("No OpenCL platforms were found, {}", err.what()); } return platforms; } @@ -500,7 +500,7 @@ class CLInfo final : public ICLInfo { std::vector platformDevices; platform.getDevices(CL_DEVICE_TYPE_ALL, &platformDevices); - spdlog::get(logger)->info("Found OpenCL devices: {}", platformDevices.size()); + logger()->info("Found OpenCL devices: {}", platformDevices.size()); devices.insert( devices.end(), std::make_move_iterator(platformDevices.begin()), @@ -508,7 +508,7 @@ class CLInfo final : public ICLInfo } catch (cl::Error& err) { - spdlog::get(logger)->error("No OpenCL devices were found, {}", err.what()); + logger()->error("No OpenCL devices were found, {}", err.what()); } } return devices; @@ -536,7 +536,7 @@ class CLInfo final : public ICLInfo } catch (cl::Error& err) { - spdlog::get(logger)->error("Failed to get description for the selected device, {}", err.what()); + logger()->error("Failed to get description for the selected device, {}", err.what()); } return "unknown"; } @@ -551,7 +551,7 @@ class CLInfo final : public ICLInfo } catch (const cl::Error& err) { - spdlog::get(logger)->error("Failed to get power index for the device, {}", err.what()); + logger()->error("Failed to get power index for the device, {}", err.what()); } return 0; } diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index 5efa4b8..d22ecc2 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -6,6 +6,7 @@ // #include "diagnostics.hpp" +#include "log.hpp" #include "utils.hpp" #include @@ -13,7 +14,6 @@ #include #include #include -#include #include // std::runtime_error, std::invalid_argument #include @@ -21,7 +21,7 @@ using namespace nlohmann; namespace { -constexpr char logger[] = "diagnostics"; +auto logger() { return spdlog::get(ocls::LogName::diagnostics); } } // namespace @@ -82,7 +82,7 @@ class DiagnosticsParser final : public IDiagnosticsParser { if (count++ >= problemsLimit) { - spdlog::get(logger)->info("Maximum number of problems reached, other problems will be skipped"); + logger()->info("Maximum number of problems reached, other problems will be skipped"); break; } diagnostics.emplace_back(CreateDiagnostic(matches, name)); @@ -138,26 +138,25 @@ void Diagnostics::SetBuildOptions(const json& options) } catch (std::exception& e) { - spdlog::get(logger)->error("Failed to parse build options, {}", e.what()); + logger()->error("Failed to parse build options, {}", e.what()); } } void Diagnostics::SetBuildOptions(const std::string& options) { m_BuildOptions = options; - spdlog::get(logger)->trace("Set build options, {}", m_BuildOptions); + logger()->trace("Set build options, {}", m_BuildOptions); } void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) { - spdlog::get(logger)->trace("Set max number of problems: {}", maxNumberOfProblems); + logger()->trace("Set max number of problems: {}", maxNumberOfProblems); m_maxNumberOfProblems = maxNumberOfProblems; } void Diagnostics::SetOpenCLDevice(uint32_t identifier) { - auto log = spdlog::get(logger); - log->trace("Selecting OpenCL device..."); + logger()->trace("Selecting OpenCL device..."); const auto devices = m_clInfo->GetDevices(); @@ -169,13 +168,13 @@ void Diagnostics::SetOpenCLDevice(uint32_t identifier) auto selectedDevice = SelectOpenCLDevice(devices, identifier); if (!selectedDevice) { - log->warn("No suitable OpenCL device was found."); + logger()->warn("No suitable OpenCL device was found."); return; } m_device = selectedDevice; auto description = m_clInfo->GetDeviceDescription(*m_device); - log->info("Selected OpenCL device: {}", description); + logger()->info("Selected OpenCL device: {}", description); } std::string Diagnostics::GetBuildLog(const Source& source) @@ -184,7 +183,7 @@ std::string Diagnostics::GetBuildLog(const Source& source) { throw std::runtime_error("missing OpenCL device"); } - spdlog::get(logger)->trace("Getting diagnostics..."); + logger()->trace("Getting diagnostics..."); return BuildSource(source.text); } @@ -198,7 +197,7 @@ nlohmann::json Diagnostics::GetDiagnostics(const Source& source) srcName = std::filesystem::path(filePath).filename().string(); } buildLog = BuildSource(source.text); - spdlog::get(logger)->trace("BuildLog:\n{}", buildLog); + logger()->trace("BuildLog:\n{}", buildLog); return m_parser->ParseDiagnostics(buildLog, srcName, m_maxNumberOfProblems); } @@ -222,8 +221,6 @@ std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std: std::optional Diagnostics::SelectOpenCLDevice(const std::vector& devices, uint32_t identifier) { - auto log = spdlog::get(logger); - if (identifier > 0) { auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { @@ -265,7 +262,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const cl::Program program; try { - spdlog::get(logger)->debug("Building program with options: {}", m_BuildOptions); + logger()->debug("Building program with options: {}", m_BuildOptions); program = cl::Program(context, source, false); program.build(ds, m_BuildOptions.c_str()); } @@ -273,7 +270,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const { if (err.err() != CL_BUILD_PROGRAM_FAILURE) { - spdlog::get(logger)->error("Failed to build program: {} ({})", err.what(), err.err()); + logger()->error("Failed to build program: {} ({})", err.what(), err.err()); throw err; } } @@ -286,7 +283,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const } catch (cl::Error& err) { - spdlog::get(logger)->error("Failed get build info, error, {}", err.what()); + logger()->error("Failed get build info, error, {}", err.what()); } return build_log; diff --git a/src/jsonrpc.cpp b/src/jsonrpc.cpp index 15e53d6..641583b 100644 --- a/src/jsonrpc.cpp +++ b/src/jsonrpc.cpp @@ -6,10 +6,11 @@ // #include "jsonrpc.hpp" +#include "log.hpp" +#include "utils.hpp" #include #include -#include #include using namespace nlohmann; @@ -18,13 +19,9 @@ namespace ocls { namespace { -constexpr char logger[] = "jrpc"; -constexpr char LE[] = "\r\n"; +auto logger() { return spdlog::get(ocls::LogName::jrpc); } -inline std::string FormatBool(bool flag) -{ - return flag ? "yes" : "no"; -} +constexpr char LE[] = "\r\n"; } // namespace @@ -97,19 +94,19 @@ class JsonRPC final : public IJsonRPC void JsonRPC::RegisterMethodCallback(const std::string& method, InputCallbackFunc&& func) { - spdlog::get(logger)->trace("Set callback for method: {}", method); + logger()->trace("Set callback for method: {}", method); m_callbacks[method] = std::move(func); } void JsonRPC::RegisterInputCallback(InputCallbackFunc&& func) { - spdlog::get(logger)->trace("Set callback for client responds"); + logger()->trace("Set callback for client responds"); m_respondCallback = std::move(func); } void JsonRPC::RegisterOutputCallback(OutputCallbackFunc&& func) { - spdlog::get(logger)->trace("Set output callback"); + logger()->trace("Set output callback"); m_outputCallback = std::move(func); } @@ -153,7 +150,7 @@ void JsonRPC::Write(const json& data) const } catch (std::exception& err) { - spdlog::get(logger)->error("Failed to write message: '{}', error: {}", message, err.what()); + logger()->error("Failed to write message: '{}', error: {}", message, err.what()); } } @@ -172,15 +169,15 @@ void JsonRPC::WriteTrace(const std::string& message, const std::string& verbose) { if (!m_tracing) { - spdlog::get(logger)->debug("JRPC tracing is disabled"); - spdlog::get(logger)->trace("The message was: '{}', verbose: {}", message, verbose); + logger()->debug("JRPC tracing is disabled"); + logger()->trace("The message was: '{}', verbose: {}", message, verbose); return; } if (!verbose.empty() && !m_verbosity) { - spdlog::get(logger)->debug("JRPC verbose tracing is disabled"); - spdlog::get(logger)->trace("The verbose message was: {}", verbose); + logger()->debug("JRPC verbose tracing is disabled"); + logger()->trace("The verbose message was: {}", verbose); } // clang-format off @@ -273,12 +270,12 @@ void JsonRPC::OnInitialize() m_tracing = traceValue != "off"; m_verbosity = traceValue == "verbose"; m_initialized = true; - spdlog::get(logger)->debug( - "Tracing options: is verbose: {}, is on: {}", FormatBool(m_verbosity), FormatBool(m_tracing)); + logger()->debug( + "Tracing options: is verbose: {}, is on: {}", utils::FormatBool(m_verbosity), utils::FormatBool(m_tracing)); } catch (std::exception& err) { - spdlog::get(logger)->error("Failed to read tracing options, {}", err.what()); + logger()->error("Failed to read tracing options, {}", err.what()); } } @@ -289,12 +286,12 @@ void JsonRPC::OnTracingChanged(const json& data) const auto traceValue = data["params"]["value"].get(); m_tracing = traceValue != "off"; m_verbosity = traceValue == "verbose"; - spdlog::get(logger)->debug( - "Tracing options were changed, is verbose: {}, is on: {}", FormatBool(m_verbosity), FormatBool(m_tracing)); + logger()->debug( + "Tracing options were changed, is verbose: {}, is on: {}", utils::FormatBool(m_verbosity), utils::FormatBool(m_tracing)); } catch (std::exception& err) { - spdlog::get(logger)->error("Failed to read tracing options, {}", err.what()); + logger()->error("Failed to read tracing options, {}", err.what()); } } @@ -322,7 +319,7 @@ void JsonRPC::FireRespondCallback() { if (m_respondCallback) { - spdlog::get(logger)->debug("Calling handler for a client respond"); + logger()->debug("Calling handler for a client respond"); m_respondCallback(m_body); } } @@ -335,8 +332,8 @@ void JsonRPC::FireMethodCallback() { const bool isRequest = m_body["params"]["id"] != nullptr; const bool mustRespond = isRequest || m_method.rfind("$/", 0) == std::string::npos; - spdlog::get(logger)->debug( - "Got request: {}, respond is required: {}", FormatBool(isRequest), FormatBool(mustRespond)); + logger()->debug( + "Got request: {}, respond is required: {}", utils::FormatBool(isRequest), utils::FormatBool(mustRespond)); if (mustRespond) { WriteError(JRPCErrorCode::MethodNotFound, "Method '" + m_method + "' is not supported."); @@ -346,19 +343,19 @@ void JsonRPC::FireMethodCallback() { try { - spdlog::get(logger)->debug("Calling handler for method: '{}'", m_method); + logger()->debug("Calling handler for method: '{}'", m_method); callback->second(m_body); } catch (std::exception& err) { - spdlog::get(logger)->error("Failed to handle method '{}', err: {}", m_method, err.what()); + logger()->error("Failed to handle method '{}', err: {}", m_method, err.what()); } } } void JsonRPC::WriteError(JRPCErrorCode errorCode, const std::string& message) const { - spdlog::get(logger)->trace("Reporting error: '{}' ({})", message, static_cast(errorCode)); + logger()->trace("Reporting error: '{}' ({})", message, static_cast(errorCode)); json obj = { {"error", { @@ -375,15 +372,15 @@ void JsonRPC::LogBufferContent() const return; } - spdlog::get(logger)->debug(""); - spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); + logger()->debug(""); + logger()->debug(">>>>>>>>>>>>>>>>"); for (auto& header : m_headers) { - spdlog::get(logger)->debug(header.first, ": ", header.second); + logger()->debug(header.first, ": ", header.second); } - spdlog::get(logger)->debug(m_buffer); - spdlog::get(logger)->debug(">>>>>>>>>>>>>>>>"); - spdlog::get(logger)->debug(""); + logger()->debug(m_buffer); + logger()->debug(">>>>>>>>>>>>>>>>"); + logger()->debug(""); } void JsonRPC::LogMessage(const std::string& message) const @@ -393,23 +390,23 @@ void JsonRPC::LogMessage(const std::string& message) const return; } - spdlog::get(logger)->debug(""); - spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); - spdlog::get(logger)->debug(message); - spdlog::get(logger)->debug("<<<<<<<<<<<<<<<<"); - spdlog::get(logger)->debug(""); + logger()->debug(""); + logger()->debug("<<<<<<<<<<<<<<<<"); + logger()->debug(message); + logger()->debug("<<<<<<<<<<<<<<<<"); + logger()->debug(""); } void JsonRPC::LogAndHandleParseError(std::exception& e) { - spdlog::get(logger)->error("Failed to parse request with reason: '{}'\n{}", e.what(), "\n", m_buffer); + logger()->error("Failed to parse request with reason: '{}'\n{}", e.what(), "\n", m_buffer); m_buffer.clear(); WriteError(JRPCErrorCode::ParseError, "Failed to parse request"); } void JsonRPC::LogAndHandleUnexpectedMessage() { - spdlog::get(logger)->error("Unexpected first message: '{}'", m_method); + logger()->error("Unexpected first message: '{}'", m_method); WriteError(JRPCErrorCode::NotInitialized, "Server was not initialized."); } diff --git a/src/log.cpp b/src/log.cpp new file mode 100644 index 0000000..b81c2b2 --- /dev/null +++ b/src/log.cpp @@ -0,0 +1,67 @@ +// +// logs.cpp +// opencl-language-server +// +// Created by Ilia Shoshin on 20/08/23. +// + +#include "log.hpp" + +#include + +#include +#include + +namespace ocls { + +std::string LogName::main = "opencl-ls"; +std::string LogName::clinfo = "clinfo"; +std::string LogName::diagnostics = "diagnostics"; +std::string LogName::jrpc = "jrpc"; +std::string LogName::lsp = "lsp"; + +void ConfigureLogging(bool fileLogging, const std::string& filename, spdlog::level::level_enum level) +{ + try + { + spdlog::sink_ptr sink; + if (fileLogging) + { + sink = std::make_shared(filename); + } + else + { + sink = std::make_shared(); + } + + spdlog::set_default_logger(std::make_shared(LogName::main, sink)); + spdlog::set_level(level); + std::vector> subLoggers = { + std::make_shared(LogName::clinfo, sink), + std::make_shared(LogName::diagnostics, sink), + std::make_shared(LogName::jrpc, sink), + std::make_shared(LogName::lsp, sink)}; + + for (const auto& logger : subLoggers) + { + logger->set_level(level); + spdlog::register_logger(logger); + } + } + catch (const spdlog::spdlog_ex& ex) + { + std::cerr << "Log init failed: " << ex.what() << std::endl; + } +} + +void ConfigureFileLogging(const std::string& filename, spdlog::level::level_enum level) +{ + ConfigureLogging(true, filename, level); +} + +void ConfigureNullLogging() +{ + ConfigureLogging(false, "", spdlog::level::trace); +} + +} // namespace ocls diff --git a/src/lsp.cpp b/src/lsp.cpp index 9442757..724ce2e 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -5,13 +5,14 @@ // Created by Ilia Shoshin on 7/16/21. // -#include "lsp.hpp" #include "diagnostics.hpp" #include "jsonrpc.hpp" +#include "log.hpp" +#include "lsp.hpp" +#include "utils.hpp" #include #include -#include #include using namespace nlohmann; @@ -41,7 +42,7 @@ std::optional GetNestedValue(const nlohmann::json &j, const std: namespace ocls { -constexpr char logger[] = "lsp"; +auto logger() { return spdlog::get(ocls::LogName::lsp); } struct Capabilities { @@ -123,11 +124,11 @@ void LSPServerEventsHandler::GetConfiguration() { if (!m_capabilities.hasConfigurationCapability) { - spdlog::get(logger)->debug("Does not have configuration capability"); + logger()->debug("Does not have configuration capability"); return; } - spdlog::get(logger)->debug("Make configuration request"); + logger()->debug("Make configuration request"); json buildOptions = {{"section", "OpenCL.server.buildOptions"}}; json maxNumberOfProblems = {{"section", "OpenCL.server.maxNumberOfProblems"}}; json openCLDeviceID = {{"section", "OpenCL.server.deviceID"}}; @@ -153,10 +154,10 @@ std::optional LSPServerEventsHandler::GetNextResponse() void LSPServerEventsHandler::OnInitialize(const json &data) { - spdlog::get(logger)->debug("Received 'initialize' request"); + logger()->debug("Received 'initialize' request"); if (!data.contains("id")) { - spdlog::get(logger)->error("'initialize' message does not contain 'id'"); + logger()->error("'initialize' message does not contain 'id'"); return; } auto requestId = data["id"]; @@ -210,10 +211,10 @@ void LSPServerEventsHandler::OnInitialize(const json &data) void LSPServerEventsHandler::OnInitialized(const json &data) { - spdlog::get(logger)->debug("Received 'initialized' message"); + logger()->debug("Received 'initialized' message"); if (!data.contains("id")) { - spdlog::get(logger)->error("'initialized' message does not contain 'id'"); + logger()->error("'initialized' message does not contain 'id'"); return; } @@ -221,7 +222,7 @@ void LSPServerEventsHandler::OnInitialized(const json &data) if (!m_capabilities.supportDidChangeConfiguration) { - spdlog::get(logger)->debug("Does not support didChangeConfiguration registration"); + logger()->debug("Does not support didChangeConfiguration registration"); return; } @@ -242,7 +243,7 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con try { const auto filePath = utils::UriToFilePath(uri); - spdlog::get(logger)->debug("Converted uri '{}' to path '{}'", uri, filePath); + logger()->debug("Converted uri '{}' to path '{}'", uri, filePath); json diags = m_diagnostics->GetDiagnostics({filePath, content}); m_outQueue.push( @@ -256,14 +257,14 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con catch (std::exception &err) { auto msg = std::string("Failed to get diagnostics: ") + err.what(); - spdlog::get(logger)->error(msg); + logger()->error(msg); m_jrpc->WriteError(JRPCErrorCode::InternalError, msg); } } void LSPServerEventsHandler::OnTextOpen(const json &data) { - spdlog::get(logger)->debug("Received 'textOpen' message"); + logger()->debug("Received 'textOpen' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); auto content = GetNestedValue(data, {"params", "textDocument", "text"}); if (uri && content) @@ -274,7 +275,7 @@ void LSPServerEventsHandler::OnTextOpen(const json &data) void LSPServerEventsHandler::OnTextChanged(const json &data) { - spdlog::get(logger)->debug("Received 'textChanged' message"); + logger()->debug("Received 'textChanged' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); auto contentChanges = GetNestedValue(data, {"params", "contentChanges"}); if (contentChanges && contentChanges->size() > 0) @@ -292,21 +293,20 @@ void LSPServerEventsHandler::OnTextChanged(const json &data) void LSPServerEventsHandler::OnConfiguration(const json &data) { - auto log = spdlog::get(logger); - log->debug("Received 'configuration' respond"); + logger()->debug("Received 'configuration' respond"); try { auto result = data.at("result"); if (result.empty()) { - log->warn("Empty configuration"); + logger()->warn("Empty configuration"); return; } if (result.size() < NumConfigurations) { - log->warn("Unexpected number of options"); + logger()->warn("Unexpected number of options"); return; } @@ -329,16 +329,16 @@ void LSPServerEventsHandler::OnConfiguration(const json &data) } catch (std::exception &err) { - log->error("Failed to update settings, {}", err.what()); + logger()->error("Failed to update settings, {}", err.what()); } } void LSPServerEventsHandler::OnRespond(const json &data) { - spdlog::get(logger)->debug("Received client respond"); + logger()->debug("Received client respond"); if (m_requests.empty()) { - spdlog::get(logger)->warn("Unexpected respond {}", data.dump()); + logger()->warn("Unexpected respond {}", data.dump()); return; } @@ -352,26 +352,26 @@ void LSPServerEventsHandler::OnRespond(const json &data) } else { - spdlog::get(logger)->warn("Out of order respond"); + logger()->warn("Out of order respond"); } m_requests.pop(); } catch (std::exception &err) { - spdlog::get(logger)->error("OnRespond failed, {}", err.what()); + logger()->error("OnRespond failed, {}", err.what()); } } void LSPServerEventsHandler::OnShutdown(const json &data) { - spdlog::get(logger)->debug("Received 'shutdown' request"); + logger()->debug("Received 'shutdown' request"); m_outQueue.push({{"id", data["id"]}, {"result", nullptr}}); m_shutdown = true; } void LSPServerEventsHandler::OnExit() { - spdlog::get(logger)->debug("Received 'exit', after 'shutdown': {}", m_shutdown ? "yes" : "no"); + logger()->debug("Received 'exit', after 'shutdown': {}", utils::FormatBool(m_shutdown)); if (m_shutdown) { m_exitHandler->OnExit(EXIT_SUCCESS); @@ -386,7 +386,7 @@ void LSPServerEventsHandler::OnExit() int LSPServer::Run() { - spdlog::get(logger)->info("Setting up..."); + logger()->info("Setting up..."); auto self = this->shared_from_this(); // clang-format off // Register handlers for methods @@ -435,7 +435,7 @@ int LSPServer::Run() }); // clang-format on - spdlog::get(logger)->info("Listening..."); + logger()->info("Listening..."); char c; while (std::cin.get(c)) { diff --git a/src/main.cpp b/src/main.cpp index 28db961..a8e32fe 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -10,15 +10,13 @@ #include "clinfo.hpp" #include "diagnostics.hpp" #include "jsonrpc.hpp" +#include "log.hpp" #include "lsp.hpp" #include "version.hpp" #include #include #include -#include -#include -#include #if defined(WIN32) #include @@ -147,38 +145,6 @@ static void SignalHandler(int) } } -void ConfigureLogging(bool fileLogging, std::string filename, spdlog::level::level_enum level) -{ - try - { - spdlog::sink_ptr sink; - if (fileLogging) - { - sink = std::make_shared(filename); - } - else - { - sink = std::make_shared(); - } - spdlog::set_default_logger(std::make_shared("opencl-ls", sink)); - spdlog::set_level(level); - std::vector> subLoggers = { - std::make_shared("clinfo", sink), - std::make_shared("diagnostics", sink), - std::make_shared("jrpc", sink), - std::make_shared("lsp", sink)}; - for (const auto& logger : subLoggers) - { - logger->set_level(level); - spdlog::register_logger(logger); - } - } - catch (const spdlog::spdlog_ex& ex) - { - std::cerr << "Log init failed: " << ex.what() << std::endl; - } -} - inline void SetupBinaryStreamMode() { #if defined(WIN32) @@ -229,7 +195,14 @@ int main(int argc, char* argv[]) CLInfoSubCommand clInfoCmd(app); DiagnosticsSubCommand diagnosticsCmd(app); CLI11_PARSE(app, argc, argv); - ConfigureLogging(flagLogTofile, optLogFile, optLogLevel); + if(flagLogTofile) + { + ConfigureFileLogging(optLogFile, optLogLevel); + } + else + { + ConfigureNullLogging(); + } auto clinfo = CreateCLInfo(); if (clInfoCmd.IsParsed()) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cff36c1..800d788 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,6 +2,7 @@ set(TESTS_PROJECT_NAME ${PROJECT_NAME}-tests) set(sources diagnostics.cpp jsonrpc.cpp + log.cpp lsp.cpp utils.cpp ) diff --git a/tests/main.cpp b/tests/main.cpp index 1c00956..7aa626b 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -5,24 +5,14 @@ // Created by Ilia Shoshin on 7/16/21. // +#include "log.hpp" + #include #include -#include -#include int main(int argc, char** argv) { + ocls::ConfigureNullLogging(); ::testing::InitGoogleTest(&argc, argv); - auto sink = std::make_shared(); - auto mainLogger = std::make_shared("opencl-language-server", sink); - auto clinfoLogger = std::make_shared("clinfo", sink); - auto diagnosticsLogger = std::make_shared("diagnostics", sink); - auto jsonrpcLogger = std::make_shared("jrpc", sink); - auto lspLogger = std::make_shared("lsp", sink); - spdlog::set_default_logger(mainLogger); - spdlog::register_logger(clinfoLogger); - spdlog::register_logger(diagnosticsLogger); - spdlog::register_logger(jsonrpcLogger); - spdlog::register_logger(lspLogger); return RUN_ALL_TESTS(); } \ No newline at end of file From 5e914ccf5897ec8d67e267cf44f996c3d22aabb1 Mon Sep 17 00:00:00 2001 From: Galarius Date: Mon, 21 Aug 2023 20:53:13 +0300 Subject: [PATCH 26/32] Refactor log messages --- include/utils.hpp | 2 ++ src/clinfo.cpp | 75 +++++++++++++++++++++++++++++---------------- src/diagnostics.cpp | 25 ++++++++++----- src/jsonrpc.cpp | 44 ++++++++++++++------------ src/log.cpp | 4 +++ src/lsp.cpp | 29 +++++++++--------- src/utils.cpp | 10 ++++++ 7 files changed, 120 insertions(+), 69 deletions(-) diff --git a/include/utils.hpp b/include/utils.hpp index 37f7c4a..956b619 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -35,6 +35,8 @@ std::shared_ptr CreateDefaultGenerator(); std::shared_ptr CreateDefaultExitHandler(); +std::string GetCurrentDateTime(); + // --- String Helpers --- std::vector SplitString(const std::string& str, const std::string& pattern); diff --git a/src/clinfo.cpp b/src/clinfo.cpp index 2f5b639..cf5e24e 100644 --- a/src/clinfo.cpp +++ b/src/clinfo.cpp @@ -18,7 +18,10 @@ using ocls::ICLInfo; namespace { -auto logger() { return spdlog::get(ocls::LogName::clinfo); } +auto logger() +{ + return spdlog::get(ocls::LogName::clinfo); +} const std::unordered_map booleanChoices { {CL_TRUE, "CL_TRUE"}, @@ -351,7 +354,6 @@ json::object_t GetDeviceJSONInfo(const cl::Device& device) catch (const cl::Error& err) { logger()->error("Failed to get info for the device, {}", err.what()); - continue; } } return info; @@ -373,7 +375,7 @@ uint32_t CalculateDeviceID(const cl::Device& device) } catch (const cl::Error& err) { - logger()->error("Failed to calculate device uuid, {}", err.what()); + logger()->error("Failed to calculate the device uuid, {}", err.what()); } return 0; } @@ -403,7 +405,7 @@ uint32_t CalculatePlatformID(const cl::Platform& platform) } catch (const cl::Error& err) { - logger()->error("Failed to calculate platform uuid, {}", err.what()); + logger()->error("Failed to calculate the platform uuid, {}", err.what()); } return 0; } @@ -421,7 +423,7 @@ json GetPlatformJSONInfo(const cl::Platform& platform) } catch (const cl::Error& err) { - logger()->error("Failed to get info for a platform, {}", err.what()); + logger()->error("Failed to get information for the platform, {}", err.what()); } } @@ -438,7 +440,7 @@ json GetPlatformJSONInfo(const cl::Platform& platform) } catch (const cl::Error& err) { - logger()->error("Failed to get devices for a platform, {}", err.what()); + logger()->error("Failed to get platform's devices, {}", err.what()); } return info; @@ -449,26 +451,16 @@ class CLInfo final : public ICLInfo public: nlohmann::json json() { - logger()->trace("Searching for OpenCL platforms..."); - std::vector platforms; - try - { - cl::Platform::get(&platforms); - } - catch (const cl::Error& err) - { - logger()->error("No OpenCL platforms were found ({})", err.what()); - } - - logger()->info("Found OpenCL platforms, {}", platforms.size()); + const auto platforms = GetPlatforms(); if (platforms.size() == 0) { return {}; } std::vector jsonPlatforms; - for (auto& platform : platforms) + for (const auto& platform : platforms) { + logger()->trace("{}", GetPlatformDescription(platform)); jsonPlatforms.emplace_back(GetPlatformJSONInfo(platform)); } @@ -477,30 +469,59 @@ class CLInfo final : public ICLInfo std::vector GetPlatforms() { + logger()->trace("Searching for OpenCL platforms..."); std::vector platforms; try { cl::Platform::get(&platforms); - logger()->info("Found OpenCL platforms: {}", platforms.size()); + logger()->trace("Found OpenCL platforms: {}", platforms.size()); } catch (cl::Error& err) { - logger()->error("No OpenCL platforms were found, {}", err.what()); + logger()->error("Failed to find OpenCL platforms, {}", err.what()); } return platforms; } + std::string GetPlatformDescription(const cl::Platform& platform) + { + try + { + auto name = platform.getInfo(); + auto vendor = platform.getInfo(); + auto version = platform.getInfo(); + auto profile = platform.getInfo(); + auto description = "name: " + std::move(name) + "; " + "vendor: " + std::move(vendor) + "; " + + "version: " + std::move(version) + "; " + "profile: " + std::move(profile); + return description; + } + catch (cl::Error& err) + { + logger()->error("Failed to get the platform's description, {}", err.what()); + } + return "unknown"; + } + std::vector GetDevices() { std::vector devices; - const auto platforms = GetPlatforms(); - for (const auto& platform : platforms) + auto platforms = GetPlatforms(); + for (auto& platform : platforms) { + logger()->trace("Platform {}", GetPlatformDescription(platform)); + logger()->trace("Searching for platform's devices..."); try { std::vector platformDevices; platform.getDevices(CL_DEVICE_TYPE_ALL, &platformDevices); - logger()->info("Found OpenCL devices: {}", platformDevices.size()); + if (logger()->level() <= spdlog::level::trace) + { + logger()->trace("Found OpenCL devices: {}", platformDevices.size()); + for (auto& device : platformDevices) + { + logger()->trace("Device {}", GetDeviceDescription(device)); + } + } devices.insert( devices.end(), std::make_move_iterator(platformDevices.begin()), @@ -508,7 +529,7 @@ class CLInfo final : public ICLInfo } catch (cl::Error& err) { - logger()->error("No OpenCL devices were found, {}", err.what()); + logger()->error("Failed to find the platform's devices, {}", err.what()); } } return devices; @@ -536,7 +557,7 @@ class CLInfo final : public ICLInfo } catch (cl::Error& err) { - logger()->error("Failed to get description for the selected device, {}", err.what()); + logger()->error("Failed to get the device's description, {}", err.what()); } return "unknown"; } @@ -551,7 +572,7 @@ class CLInfo final : public ICLInfo } catch (const cl::Error& err) { - logger()->error("Failed to get power index for the device, {}", err.what()); + logger()->error("Failed to get the device's power index, {}", err.what()); } return 0; } diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index d22ecc2..dac7381 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -82,7 +82,7 @@ class DiagnosticsParser final : public IDiagnosticsParser { if (count++ >= problemsLimit) { - logger()->info("Maximum number of problems reached, other problems will be skipped"); + logger()->warn("Maximum number of problems reached, other problems will be skipped"); break; } diagnostics.emplace_back(CreateDiagnostic(matches, name)); @@ -145,7 +145,7 @@ void Diagnostics::SetBuildOptions(const json& options) void Diagnostics::SetBuildOptions(const std::string& options) { m_BuildOptions = options; - logger()->trace("Set build options, {}", m_BuildOptions); + logger()->trace("Set build options: {}", m_BuildOptions); } void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) @@ -156,7 +156,7 @@ void Diagnostics::SetMaxProblemsCount(uint64_t maxNumberOfProblems) void Diagnostics::SetOpenCLDevice(uint32_t identifier) { - logger()->trace("Selecting OpenCL device..."); + logger()->trace("Selecting OpenCL device [{}]...", identifier); const auto devices = m_clInfo->GetDevices(); @@ -168,13 +168,13 @@ void Diagnostics::SetOpenCLDevice(uint32_t identifier) auto selectedDevice = SelectOpenCLDevice(devices, identifier); if (!selectedDevice) { - logger()->warn("No suitable OpenCL device was found."); + logger()->warn("No suitable OpenCL device was found"); return; } m_device = selectedDevice; auto description = m_clInfo->GetDeviceDescription(*m_device); - logger()->info("Selected OpenCL device: {}", description); + logger()->debug("Selected OpenCL device: {}", description); } std::string Diagnostics::GetBuildLog(const Source& source) @@ -223,6 +223,7 @@ std::optional Diagnostics::SelectOpenCLDevice(const std::vector 0) { + logger()->trace("Searching for the device by ID '{}'...", identifier); auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { try { @@ -241,6 +242,7 @@ std::optional Diagnostics::SelectOpenCLDevice(const std::vectortrace("Searching for the device by power index..."); auto device = SelectOpenCLDeviceByPowerIndex(devices); if (device && (!m_device || m_clInfo->GetDevicePowerIndex(*device) > m_clInfo->GetDevicePowerIndex(*m_device))) { @@ -262,7 +264,14 @@ std::string Diagnostics::BuildSource(const std::string& source) const cl::Program program; try { - logger()->debug("Building program with options: {}", m_BuildOptions); + if(m_BuildOptions.empty()) + { + logger()->trace("Building program..."); + } + else + { + logger()->trace("Building program with options: {}...", m_BuildOptions); + } program = cl::Program(context, source, false); program.build(ds, m_BuildOptions.c_str()); } @@ -270,7 +279,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const { if (err.err() != CL_BUILD_PROGRAM_FAILURE) { - logger()->error("Failed to build program: {} ({})", err.what(), err.err()); + logger()->error("Failed to build the program: {} ({})", err.what(), err.err()); throw err; } } @@ -283,7 +292,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const } catch (cl::Error& err) { - logger()->error("Failed get build info, error, {}", err.what()); + logger()->error("Failed get build info, {}", err.what()); } return build_log; diff --git a/src/jsonrpc.cpp b/src/jsonrpc.cpp index 641583b..cb166e2 100644 --- a/src/jsonrpc.cpp +++ b/src/jsonrpc.cpp @@ -11,6 +11,7 @@ #include #include +#include #include using namespace nlohmann; @@ -270,8 +271,9 @@ void JsonRPC::OnInitialize() m_tracing = traceValue != "off"; m_verbosity = traceValue == "verbose"; m_initialized = true; - logger()->debug( - "Tracing options: is verbose: {}, is on: {}", utils::FormatBool(m_verbosity), utils::FormatBool(m_tracing)); + logger()->trace("Tracing options: is verbose: {}, is on: {}", + utils::FormatBool(m_verbosity), + utils::FormatBool(m_tracing)); } catch (std::exception& err) { @@ -286,8 +288,9 @@ void JsonRPC::OnTracingChanged(const json& data) const auto traceValue = data["params"]["value"].get(); m_tracing = traceValue != "off"; m_verbosity = traceValue == "verbose"; - logger()->debug( - "Tracing options were changed, is verbose: {}, is on: {}", utils::FormatBool(m_verbosity), utils::FormatBool(m_tracing)); + logger()->trace("Tracing options were changed, is verbose: {}, is on: {}", + utils::FormatBool(m_verbosity), + utils::FormatBool(m_tracing)); } catch (std::exception& err) { @@ -319,7 +322,7 @@ void JsonRPC::FireRespondCallback() { if (m_respondCallback) { - logger()->debug("Calling handler for a client respond"); + logger()->trace("Calling handler for a client respond"); m_respondCallback(m_body); } } @@ -332,8 +335,9 @@ void JsonRPC::FireMethodCallback() { const bool isRequest = m_body["params"]["id"] != nullptr; const bool mustRespond = isRequest || m_method.rfind("$/", 0) == std::string::npos; - logger()->debug( - "Got request: {}, respond is required: {}", utils::FormatBool(isRequest), utils::FormatBool(mustRespond)); + logger()->trace("Got request: {}, respond is required: {}", + utils::FormatBool(isRequest), + utils::FormatBool(mustRespond)); if (mustRespond) { WriteError(JRPCErrorCode::MethodNotFound, "Method '" + m_method + "' is not supported."); @@ -343,7 +347,7 @@ void JsonRPC::FireMethodCallback() { try { - logger()->debug("Calling handler for method: '{}'", m_method); + logger()->trace("Calling handler for method: '{}'", m_method); callback->second(m_body); } catch (std::exception& err) @@ -372,15 +376,16 @@ void JsonRPC::LogBufferContent() const return; } - logger()->debug(""); - logger()->debug(">>>>>>>>>>>>>>>>"); + std::stringstream ss; + ss << "\n>>>>>>>>>>>>>>>>\n"; for (auto& header : m_headers) { - logger()->debug(header.first, ": ", header.second); + ss << header.first << ": " << header.second; } - logger()->debug(m_buffer); - logger()->debug(">>>>>>>>>>>>>>>>"); - logger()->debug(""); + ss << m_buffer; + ss << "\n>>>>>>>>>>>>>>>>\n"; + + logger()->debug(ss.str()); } void JsonRPC::LogMessage(const std::string& message) const @@ -389,12 +394,13 @@ void JsonRPC::LogMessage(const std::string& message) const { return; } + + std::stringstream ss; + ss << "\n<<<<<<<<<<<<<<<<\n" + << message + << "\n<<<<<<<<<<<<<<<<\n"; - logger()->debug(""); - logger()->debug("<<<<<<<<<<<<<<<<"); - logger()->debug(message); - logger()->debug("<<<<<<<<<<<<<<<<"); - logger()->debug(""); + logger()->debug(ss.str()); } void JsonRPC::LogAndHandleParseError(std::exception& e) diff --git a/src/log.cpp b/src/log.cpp index b81c2b2..d822f6b 100644 --- a/src/log.cpp +++ b/src/log.cpp @@ -6,9 +6,11 @@ // #include "log.hpp" +#include "utils.hpp" #include +#include "spdlog/pattern_formatter.h" #include #include @@ -47,6 +49,8 @@ void ConfigureLogging(bool fileLogging, const std::string& filename, spdlog::lev logger->set_level(level); spdlog::register_logger(logger); } + + spdlog::get(LogName::main)->info("\nStarting new session at {}...\n", utils::GetCurrentDateTime()); } catch (const spdlog::spdlog_ex& ex) { diff --git a/src/lsp.cpp b/src/lsp.cpp index 724ce2e..49a8ffd 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -124,11 +124,11 @@ void LSPServerEventsHandler::GetConfiguration() { if (!m_capabilities.hasConfigurationCapability) { - logger()->debug("Does not have configuration capability"); + logger()->trace("Does not have configuration capability"); return; } - logger()->debug("Make configuration request"); + logger()->trace("Make configuration request"); json buildOptions = {{"section", "OpenCL.server.buildOptions"}}; json maxNumberOfProblems = {{"section", "OpenCL.server.maxNumberOfProblems"}}; json openCLDeviceID = {{"section", "OpenCL.server.deviceID"}}; @@ -154,7 +154,7 @@ std::optional LSPServerEventsHandler::GetNextResponse() void LSPServerEventsHandler::OnInitialize(const json &data) { - logger()->debug("Received 'initialize' request"); + logger()->trace("Received 'initialize' request"); if (!data.contains("id")) { logger()->error("'initialize' message does not contain 'id'"); @@ -211,7 +211,7 @@ void LSPServerEventsHandler::OnInitialize(const json &data) void LSPServerEventsHandler::OnInitialized(const json &data) { - logger()->debug("Received 'initialized' message"); + logger()->trace("Received 'initialized' message"); if (!data.contains("id")) { logger()->error("'initialized' message does not contain 'id'"); @@ -222,7 +222,7 @@ void LSPServerEventsHandler::OnInitialized(const json &data) if (!m_capabilities.supportDidChangeConfiguration) { - logger()->debug("Does not support didChangeConfiguration registration"); + logger()->trace("Does not support didChangeConfiguration registration"); return; } @@ -243,8 +243,7 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con try { const auto filePath = utils::UriToFilePath(uri); - logger()->debug("Converted uri '{}' to path '{}'", uri, filePath); - + logger()->trace("'{}' -> '{}'", uri, filePath); json diags = m_diagnostics->GetDiagnostics({filePath, content}); m_outQueue.push( {{"method", "textDocument/publishDiagnostics"}, @@ -264,7 +263,7 @@ void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, con void LSPServerEventsHandler::OnTextOpen(const json &data) { - logger()->debug("Received 'textOpen' message"); + logger()->trace("Received 'textOpen' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); auto content = GetNestedValue(data, {"params", "textDocument", "text"}); if (uri && content) @@ -275,7 +274,7 @@ void LSPServerEventsHandler::OnTextOpen(const json &data) void LSPServerEventsHandler::OnTextChanged(const json &data) { - logger()->debug("Received 'textChanged' message"); + logger()->trace("Received 'textChanged' message"); auto uri = GetNestedValue(data, {"params", "textDocument", "uri"}); auto contentChanges = GetNestedValue(data, {"params", "contentChanges"}); if (contentChanges && contentChanges->size() > 0) @@ -293,7 +292,7 @@ void LSPServerEventsHandler::OnTextChanged(const json &data) void LSPServerEventsHandler::OnConfiguration(const json &data) { - logger()->debug("Received 'configuration' respond"); + logger()->trace("Received 'configuration' respond"); try { @@ -335,7 +334,7 @@ void LSPServerEventsHandler::OnConfiguration(const json &data) void LSPServerEventsHandler::OnRespond(const json &data) { - logger()->debug("Received client respond"); + logger()->trace("Received client respond"); if (m_requests.empty()) { logger()->warn("Unexpected respond {}", data.dump()); @@ -364,14 +363,14 @@ void LSPServerEventsHandler::OnRespond(const json &data) void LSPServerEventsHandler::OnShutdown(const json &data) { - logger()->debug("Received 'shutdown' request"); + logger()->trace("Received 'shutdown' request"); m_outQueue.push({{"id", data["id"]}, {"result", nullptr}}); m_shutdown = true; } void LSPServerEventsHandler::OnExit() { - logger()->debug("Received 'exit', after 'shutdown': {}", utils::FormatBool(m_shutdown)); + logger()->trace("Received 'exit', after 'shutdown': {}", utils::FormatBool(m_shutdown)); if (m_shutdown) { m_exitHandler->OnExit(EXIT_SUCCESS); @@ -386,7 +385,7 @@ void LSPServerEventsHandler::OnExit() int LSPServer::Run() { - logger()->info("Setting up..."); + logger()->trace("Setting up..."); auto self = this->shared_from_this(); // clang-format off // Register handlers for methods @@ -435,7 +434,7 @@ int LSPServer::Run() }); // clang-format on - logger()->info("Listening..."); + logger()->trace("Listening..."); char c; while (std::cin.get(c)) { diff --git a/src/utils.cpp b/src/utils.cpp index b6e6e27..e81b07d 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -8,6 +8,7 @@ #include "utils.hpp" #include +#include #include #include #include @@ -67,6 +68,15 @@ std::shared_ptr CreateDefaultExitHandler() return std::make_shared(); } +std::string GetCurrentDateTime() +{ + auto now = std::chrono::system_clock::now(); + auto itt = std::chrono::system_clock::to_time_t(now); + std::stringstream ss; + ss << std::put_time(gmtime(&itt), "%Y-%m-%d %H:%M:%S"); + return ss.str(); +} + // --- String Helpers --- std::vector SplitString(const std::string& str, const std::string& pattern) From 956e04c913f4a9e8cfb12dd956865df2e8232f3d Mon Sep 17 00:00:00 2001 From: Galarius Date: Tue, 22 Aug 2023 23:22:52 +0300 Subject: [PATCH 27/32] Apply multiple fixes * Fix `diagnostics` to be empty instead of `null` * Remove duplicate `BuildSource` call * Remove `id` validation in `OnInitialized` handler * Add `--stdio` flag, since VS Code extension adds it * Update tests --- src/diagnostics.cpp | 5 ++--- src/lsp.cpp | 11 ++--------- src/main.cpp | 2 ++ tests/diagnostics-parser-tests.cpp | 4 ++-- tests/diagnostics-tests.cpp | 2 +- tests/lsp-event-handler-tests.cpp | 6 +++--- 6 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index dac7381..e9679bb 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -71,7 +71,7 @@ class DiagnosticsParser final : public IDiagnosticsParser nlohmann::json ParseDiagnostics(const std::string& buildLog, const std::string& name, uint64_t problemsLimit) { - nlohmann::json diagnostics; + nlohmann::json diagnostics = nlohmann::json::array(); std::istringstream stream(buildLog); std::string errLine; uint64_t count = 0; @@ -189,14 +189,13 @@ std::string Diagnostics::GetBuildLog(const Source& source) nlohmann::json Diagnostics::GetDiagnostics(const Source& source) { - std::string buildLog = GetBuildLog(source); std::string srcName; if (!source.filePath.empty()) { auto filePath = std::filesystem::path(source.filePath).string(); srcName = std::filesystem::path(filePath).filename().string(); } - buildLog = BuildSource(source.text); + std::string buildLog = GetBuildLog(source); logger()->trace("BuildLog:\n{}", buildLog); return m_parser->ParseDiagnostics(buildLog, srcName, m_maxNumberOfProblems); } diff --git a/src/lsp.cpp b/src/lsp.cpp index 49a8ffd..cbd2001 100644 --- a/src/lsp.cpp +++ b/src/lsp.cpp @@ -209,16 +209,9 @@ void LSPServerEventsHandler::OnInitialize(const json &data) m_outQueue.push({{"id", requestId}, {"result", {{"capabilities", capabilities}}}}); } -void LSPServerEventsHandler::OnInitialized(const json &data) +void LSPServerEventsHandler::OnInitialized(const json &) { logger()->trace("Received 'initialized' message"); - if (!data.contains("id")) - { - logger()->error("'initialized' message does not contain 'id'"); - return; - } - - auto requestId = data["id"]; if (!m_capabilities.supportDidChangeConfiguration) { @@ -235,7 +228,7 @@ void LSPServerEventsHandler::OnInitialized(const json &data) {"registrations", registrations}, }; - m_outQueue.push({{"id", requestId}, {"method", "client/registerCapability"}, {"params", params}}); + m_outQueue.push({{"id", m_generator->GenerateID()}, {"method", "client/registerCapability"}, {"params", params}}); } void LSPServerEventsHandler::BuildDiagnosticsRespond(const std::string &uri, const std::string &content) diff --git a/src/main.cpp b/src/main.cpp index a8e32fe..fa9226f 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -165,6 +165,7 @@ inline void SetupBinaryStreamMode() int main(int argc, char* argv[]) { bool flagLogTofile = false; + bool flagStdioMode = true; std::string optLogFile = "opencl-language-server.log"; spdlog::level::level_enum optLogLevel = spdlog::level::trace; @@ -184,6 +185,7 @@ int main(int argc, char* argv[]) spdlog::level::critical})) ->required(false) ->capture_default_str(); + app.add_flag("--stdio", flagStdioMode, "Use stdio transport channel for the language server"); app.add_flag_callback( "-v,--version", []() { diff --git a/tests/diagnostics-parser-tests.cpp b/tests/diagnostics-parser-tests.cpp index 8c83177..7d5f592 100644 --- a/tests/diagnostics-parser-tests.cpp +++ b/tests/diagnostics-parser-tests.cpp @@ -73,7 +73,7 @@ TEST(ParseDiagnosticsTest, NoDiagnosticMessages) std::string log = "This is a regular log with no diagnostic message."; auto parser = CreateDiagnosticsParser(); auto result = parser->ParseDiagnostics(log, "TestName", 10); - EXPECT_TRUE(result.is_null()); + EXPECT_EQ(result.size(), 0); } TEST(ParseDiagnosticsTest, SingleDiagnosticMessage) @@ -189,5 +189,5 @@ TEST(ParseDiagnosticsTest, MalformedDiagnosticMessage) std::string log = ":5:14: reference missing for citation"; auto parser = CreateDiagnosticsParser(); auto result = parser->ParseDiagnostics(log, "TestName", 10); - EXPECT_TRUE(result.is_null()); + EXPECT_EQ(result.size(), 0); } diff --git a/tests/diagnostics-tests.cpp b/tests/diagnostics-tests.cpp index d9f5fb8..57bf4da 100644 --- a/tests/diagnostics-tests.cpp +++ b/tests/diagnostics-tests.cpp @@ -77,5 +77,5 @@ TEST(DiagnosticsTest, SelectDeviceBasedOnNonExistingIndex) EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); auto diagnostics = CreateDiagnostics(mockCLInfo); - diagnostics->SetOpenCLDevice(4527288514); + diagnostics->SetOpenCLDevice(static_cast(4527288514)); } diff --git a/tests/lsp-event-handler-tests.cpp b/tests/lsp-event-handler-tests.cpp index e03af31..3383ae1 100644 --- a/tests/lsp-event-handler-tests.cpp +++ b/tests/lsp-event-handler-tests.cpp @@ -203,7 +203,7 @@ TEST_F(LSPTest, OnInitialized_withDidChangeConfigurationSupport_shouldBuildRespo })"_json; nlohmann::json expectedResponse = R"({ - "id": "1", + "id": "12345678", "method": "client/registerCapability", "params": { "registrations": [{ @@ -216,9 +216,9 @@ TEST_F(LSPTest, OnInitialized_withDidChangeConfigurationSupport_shouldBuildRespo handler->OnInitialize(initData); handler->GetNextResponse(); - EXPECT_CALL(*mockGenerator, GenerateID()).Times(1); + EXPECT_CALL(*mockGenerator, GenerateID()).Times(2); - handler->OnInitialized(R"({"id": "1"})"_json); + handler->OnInitialized({}); auto response = handler->GetNextResponse(); EXPECT_TRUE(response.has_value()); From 5b9b4a7a64365f943956b081bfb54b241cc254a6 Mon Sep 17 00:00:00 2001 From: Galarius Date: Wed, 23 Aug 2023 22:33:01 +0300 Subject: [PATCH 28/32] Fix build --- include/utils.hpp | 5 +++-- src/utils.cpp | 17 +++++++++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/include/utils.hpp b/include/utils.hpp index 956b619..4b3b7c6 100644 --- a/include/utils.hpp +++ b/include/utils.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -52,11 +53,11 @@ inline std::string FormatBool(bool flag) // --- File Helpers --- -std::string UriToFilePath(const std::string& uri, bool unix); +std::string UriToFilePath(const std::string& uri, bool isUnix); std::string UriToFilePath(const std::string& uri); -std::optional ReadFileContent(std::string_view fileName); +std::optional ReadFileContent(const std::string& fileName); // --- CRC32 --- diff --git a/src/utils.cpp b/src/utils.cpp index e81b07d..7301340 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -72,11 +72,20 @@ std::string GetCurrentDateTime() { auto now = std::chrono::system_clock::now(); auto itt = std::chrono::system_clock::to_time_t(now); + std::tm tm = {}; + +#if defined(WIN32) + gmtime_s(&tm, &itt); +#else + tm = *gmtime(&itt); +#endif + std::stringstream ss; - ss << std::put_time(gmtime(&itt), "%Y-%m-%d %H:%M:%S"); + ss << std::put_time(&tm, "%Y-%m-%d %H:%M:%S"); return ss.str(); } + // --- String Helpers --- std::vector SplitString(const std::string& str, const std::string& pattern) @@ -108,11 +117,11 @@ bool EndsWith(const std::string& str, const std::string& suffix) // --- File Helpers --- -std::string UriToFilePath(const std::string& uri, bool unix) +std::string UriToFilePath(const std::string& uri, bool isUnix) { const size_t bytesNeeded = 8 + 3 * uri.length() + 1; char* fileName = (char*)malloc(bytesNeeded * sizeof(char)); - if (unix) + if (isUnix) { if (uriUriStringToUnixFilenameA(uri.c_str(), fileName) != URI_SUCCESS) { @@ -142,7 +151,7 @@ std::string UriToFilePath(const std::string& uri) #endif } -std::optional ReadFileContent(std::string_view fileName) +std::optional ReadFileContent(const std::string& fileName) { std::string content; std::ifstream file(fileName); From 10cd9529ae9787820003f7b7f3de0b0206561532 Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 25 Aug 2023 19:18:04 +0300 Subject: [PATCH 29/32] Add a wrapper class for cl devices for better testability and performance --- CMakeLists.txt | 1 + include/clinfo.hpp | 10 +- include/device.hpp | 65 +++++++++++ include/diagnostics.hpp | 2 + src/clinfo.cpp | 183 ++++++++++++++++--------------- src/diagnostics.cpp | 46 ++++---- tests/CMakeLists.txt | 3 + tests/diagnostics-tests.cpp | 88 +++++++-------- tests/mocks/clinfo-mock.hpp | 8 +- tests/mocks/diagnostics-mock.hpp | 4 +- 10 files changed, 243 insertions(+), 167 deletions(-) create mode 100644 include/device.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c74eda8..41edff5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -58,6 +58,7 @@ endif() set(headers clinfo.hpp + device.hpp diagnostics.hpp jsonrpc.hpp log.hpp diff --git a/include/clinfo.hpp b/include/clinfo.hpp index 67bb2d9..e8a2fbe 100644 --- a/include/clinfo.hpp +++ b/include/clinfo.hpp @@ -7,7 +7,7 @@ #pragma once -#include +#include "device.hpp" #include #include @@ -21,13 +21,7 @@ struct ICLInfo virtual nlohmann::json json() = 0; - virtual std::vector GetDevices() = 0; - - virtual uint32_t GetDeviceID(const cl::Device& device) = 0; - - virtual std::string GetDeviceDescription(const cl::Device& device) = 0; - - virtual size_t GetDevicePowerIndex(const cl::Device& device) = 0; + virtual std::vector GetDevices() = 0; }; std::shared_ptr CreateCLInfo(); diff --git a/include/device.hpp b/include/device.hpp new file mode 100644 index 0000000..86cab81 --- /dev/null +++ b/include/device.hpp @@ -0,0 +1,65 @@ +// +// device.hpp +// opencl-language-server +// +// Created by Ilia Shoshin on 24.8.2023. +// + +#pragma once + +#include + +namespace ocls { + +class Device +{ +public: + + Device(const Device&) = default; + Device& operator=(const Device&) = default; + Device(Device&&) noexcept = default; + Device& operator=(Device&&) noexcept = default; + +#ifdef ENABLE_TESTING + Device(uint32_t identifier, std::string description, size_t powerIndex) + : m_identifier {identifier} + , m_description {std::move(description)} + , m_powerIndex {powerIndex} + {} +#endif + + Device(cl::Device device, uint32_t identifier, std::string description, size_t powerIndex) + : m_device {std::move(device)} + , m_identifier {identifier} + , m_description {std::move(description)} + , m_powerIndex {powerIndex} + {} + + const cl::Device& getUnderlyingDevice() const noexcept + { + return m_device; + } + + uint32_t GetID() const noexcept + { + return m_identifier; + } + + std::string GetDescription() const noexcept + { + return m_description; + } + + size_t GetPowerIndex() const noexcept + { + return m_powerIndex; + } + +private: + cl::Device m_device; + uint32_t m_identifier; + std::string m_description; + size_t m_powerIndex; +}; + +} // namespace ocls diff --git a/include/diagnostics.hpp b/include/diagnostics.hpp index 3a39d9e..d043009 100644 --- a/include/diagnostics.hpp +++ b/include/diagnostics.hpp @@ -14,6 +14,7 @@ #include #include #include +#include namespace ocls { @@ -46,6 +47,7 @@ struct IDiagnostics virtual void SetMaxProblemsCount(uint64_t maxNumberOfProblems) = 0; virtual void SetOpenCLDevice(uint32_t identifier) = 0; + virtual std::optional GetDevice() const = 0; virtual std::string GetBuildLog(const Source& source) = 0; virtual nlohmann::json GetDiagnostics(const Source& source) = 0; }; diff --git a/src/clinfo.cpp b/src/clinfo.cpp index cf5e24e..5510825 100644 --- a/src/clinfo.cpp +++ b/src/clinfo.cpp @@ -10,6 +10,7 @@ #include "utils.hpp" #include +#include #include using namespace nlohmann; @@ -446,6 +447,80 @@ json GetPlatformJSONInfo(const cl::Platform& platform) return info; } +std::vector GetPlatforms() +{ + logger()->trace("Searching for OpenCL platforms..."); + std::vector platforms; + try + { + cl::Platform::get(&platforms); + logger()->trace("Found OpenCL platforms: {}", platforms.size()); + } + catch (cl::Error& err) + { + logger()->error("Failed to find OpenCL platforms, {}", err.what()); + } + return platforms; +} + +std::string GetPlatformDescription(const cl::Platform& platform) +{ + try + { + auto name = platform.getInfo(); + auto vendor = platform.getInfo(); + auto version = platform.getInfo(); + auto profile = platform.getInfo(); + auto description = "name: " + std::move(name) + "; " + "vendor: " + std::move(vendor) + "; " + + "version: " + std::move(version) + "; " + "profile: " + std::move(profile); + return description; + } + catch (cl::Error& err) + { + logger()->error("Failed to get the platform's description, {}", err.what()); + } + return "unknown"; +} + +size_t GetDevicePowerIndex(const cl::Device& device) +{ + try + { + const size_t maxComputeUnits = device.getInfo(); + const size_t maxClockFrequency = device.getInfo(); + return maxComputeUnits * maxClockFrequency; + } + catch (const cl::Error& err) + { + logger()->error("Failed to get the device's power index, {}", err.what()); + } + return 0; +} + +std::string GetDeviceDescription(const cl::Device& device) +{ + try + { + auto name = device.getInfo(); + auto type = device.getInfo(); + auto version = device.getInfo(); + auto vendor = device.getInfo(); + auto vendorID = device.getInfo(); + auto driverVersion = device.getInfo(); + auto description = "name: " + std::move(name) + "; " + "type: " + std::to_string(type) + "; " + + "version: " + std::move(version) + "; " + "vendor: " + std::move(vendor) + "; " + + "vendorID: " + std::to_string(vendorID) + "; " + "driverVersion: " + std::move(driverVersion); + return description; + } + catch (cl::Error& err) + { + logger()->error("Failed to get the device's description, {}", err.what()); + } + return "unknown"; +} + +// --- CLInfo --- + class CLInfo final : public ICLInfo { public: @@ -467,115 +542,47 @@ class CLInfo final : public ICLInfo return nlohmann::json {{"PLATFORMS", jsonPlatforms}}; } - std::vector GetPlatforms() - { - logger()->trace("Searching for OpenCL platforms..."); - std::vector platforms; - try - { - cl::Platform::get(&platforms); - logger()->trace("Found OpenCL platforms: {}", platforms.size()); - } - catch (cl::Error& err) - { - logger()->error("Failed to find OpenCL platforms, {}", err.what()); - } - return platforms; - } - std::string GetPlatformDescription(const cl::Platform& platform) + std::vector GetDevices() { - try - { - auto name = platform.getInfo(); - auto vendor = platform.getInfo(); - auto version = platform.getInfo(); - auto profile = platform.getInfo(); - auto description = "name: " + std::move(name) + "; " + "vendor: " + std::move(vendor) + "; " + - "version: " + std::move(version) + "; " + "profile: " + std::move(profile); - return description; - } - catch (cl::Error& err) - { - logger()->error("Failed to get the platform's description, {}", err.what()); - } - return "unknown"; - } - - std::vector GetDevices() - { - std::vector devices; - auto platforms = GetPlatforms(); - for (auto& platform : platforms) + std::vector devices; + const auto platforms = GetPlatforms(); + for (const auto& platform : platforms) { logger()->trace("Platform {}", GetPlatformDescription(platform)); logger()->trace("Searching for platform's devices..."); + try { std::vector platformDevices; platform.getDevices(CL_DEVICE_TYPE_ALL, &platformDevices); + if (logger()->level() <= spdlog::level::trace) { - logger()->trace("Found OpenCL devices: {}", platformDevices.size()); - for (auto& device : platformDevices) + std::stringstream traceLog; + traceLog << "Found OpenCL devices: " << platformDevices.size() << "\n"; + for (const auto& device : platformDevices) { - logger()->trace("Device {}", GetDeviceDescription(device)); + traceLog << "Device " << GetDeviceDescription(device) << "\n"; } + logger()->trace(traceLog.str()); } - devices.insert( - devices.end(), - std::make_move_iterator(platformDevices.begin()), - std::make_move_iterator(platformDevices.end())); + + auto deviceToOclsDevice = [](const cl::Device& device) { + return ocls::Device( + device, CalculateDeviceID(device), GetDeviceDescription(device), GetDevicePowerIndex(device)); + }; + + std::transform( + platformDevices.begin(), platformDevices.end(), std::back_inserter(devices), deviceToOclsDevice); } - catch (cl::Error& err) + catch (const cl::Error& err) { logger()->error("Failed to find the platform's devices, {}", err.what()); } } return devices; } - - uint32_t GetDeviceID(const cl::Device& device) - { - return CalculateDeviceID(device); - } - - std::string GetDeviceDescription(const cl::Device& device) - { - try - { - auto name = device.getInfo(); - auto type = device.getInfo(); - auto version = device.getInfo(); - auto vendor = device.getInfo(); - auto vendorID = device.getInfo(); - auto driverVersion = device.getInfo(); - auto description = "name: " + std::move(name) + "; " + "type: " + std::to_string(type) + "; " + - "version: " + std::move(version) + "; " + "vendor: " + std::move(vendor) + "; " + - "vendorID: " + std::to_string(vendorID) + "; " + "driverVersion: " + std::move(driverVersion); - return description; - } - catch (cl::Error& err) - { - logger()->error("Failed to get the device's description, {}", err.what()); - } - return "unknown"; - } - - size_t GetDevicePowerIndex(const cl::Device& device) - { - try - { - const size_t maxComputeUnits = device.getInfo(); - const size_t maxClockFrequency = device.getInfo(); - return maxComputeUnits * maxClockFrequency; - } - catch (const cl::Error& err) - { - logger()->error("Failed to get the device's power index, {}", err.what()); - } - return 0; - } }; } // namespace diff --git a/src/diagnostics.cpp b/src/diagnostics.cpp index e9679bb..153f69f 100644 --- a/src/diagnostics.cpp +++ b/src/diagnostics.cpp @@ -5,12 +5,11 @@ // Created by Ilia Shoshin on 7/16/21. // +#include "device.hpp" #include "diagnostics.hpp" #include "log.hpp" #include "utils.hpp" -#include - #include #include #include @@ -21,7 +20,10 @@ using namespace nlohmann; namespace { -auto logger() { return spdlog::get(ocls::LogName::diagnostics); } +auto logger() +{ + return spdlog::get(ocls::LogName::diagnostics); +} } // namespace @@ -103,18 +105,19 @@ class Diagnostics final : public IDiagnostics void SetBuildOptions(const std::string& options); void SetMaxProblemsCount(uint64_t maxNumberOfProblems); void SetOpenCLDevice(uint32_t identifier); + std::optional GetDevice() const; std::string GetBuildLog(const Source& source); nlohmann::json GetDiagnostics(const Source& source); private: - std::optional SelectOpenCLDevice(const std::vector& devices, uint32_t identifier); - std::optional SelectOpenCLDeviceByPowerIndex(const std::vector& devices); + std::optional SelectOpenCLDevice(const std::vector& devices, uint32_t identifier); + std::optional SelectOpenCLDeviceByPowerIndex(const std::vector& devices); std::string BuildSource(const std::string& source) const; private: std::shared_ptr m_clInfo; std::shared_ptr m_parser; - std::optional m_device; + std::optional m_device; std::string m_BuildOptions; uint64_t m_maxNumberOfProblems = INT8_MAX; }; @@ -173,8 +176,12 @@ void Diagnostics::SetOpenCLDevice(uint32_t identifier) } m_device = selectedDevice; - auto description = m_clInfo->GetDeviceDescription(*m_device); - logger()->debug("Selected OpenCL device: {}", description); + logger()->debug("Selected OpenCL device: {}", (*m_device).GetDescription()); +} + +std::optional Diagnostics::GetDevice() const +{ + return m_device; } std::string Diagnostics::GetBuildLog(const Source& source) @@ -202,12 +209,10 @@ nlohmann::json Diagnostics::GetDiagnostics(const Source& source) // - -std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std::vector& devices) +std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std::vector& devices) { - auto maxIt = std::max_element(devices.begin(), devices.end(), [this](const cl::Device& a, const cl::Device& b) { - const auto powerIndexA = m_clInfo->GetDevicePowerIndex(a); - const auto powerIndexB = m_clInfo->GetDevicePowerIndex(b); - return powerIndexA < powerIndexB; + auto maxIt = std::max_element(devices.begin(), devices.end(), [](const ocls::Device& a, const ocls::Device& b) { + return a.GetPowerIndex() < b.GetPowerIndex(); }); if (maxIt == devices.end()) @@ -218,15 +223,16 @@ std::optional Diagnostics::SelectOpenCLDeviceByPowerIndex(const std: return *maxIt; } -std::optional Diagnostics::SelectOpenCLDevice(const std::vector& devices, uint32_t identifier) +std::optional Diagnostics::SelectOpenCLDevice( + const std::vector& devices, uint32_t identifier) { if (identifier > 0) { logger()->trace("Searching for the device by ID '{}'...", identifier); - auto it = std::find_if(devices.begin(), devices.end(), [this, &identifier](const cl::Device& device) { + auto it = std::find_if(devices.begin(), devices.end(), [&identifier](const ocls::Device& device) { try { - return m_clInfo->GetDeviceID(device) == identifier; + return device.GetID() == identifier; } catch (const cl::Error&) { @@ -243,7 +249,7 @@ std::optional Diagnostics::SelectOpenCLDevice(const std::vectortrace("Searching for the device by power index..."); auto device = SelectOpenCLDeviceByPowerIndex(devices); - if (device && (!m_device || m_clInfo->GetDevicePowerIndex(*device) > m_clInfo->GetDevicePowerIndex(*m_device))) + if (device && (!m_device || (*device).GetPowerIndex() > (*m_device).GetPowerIndex())) { return device; } @@ -258,12 +264,12 @@ std::string Diagnostics::BuildSource(const std::string& source) const throw std::runtime_error("missing OpenCL device"); } - std::vector ds {*m_device}; + std::vector ds {(*m_device).getUnderlyingDevice()}; cl::Context context(ds, NULL, NULL, NULL); cl::Program program; try { - if(m_BuildOptions.empty()) + if (m_BuildOptions.empty()) { logger()->trace("Building program..."); } @@ -287,7 +293,7 @@ std::string Diagnostics::BuildSource(const std::string& source) const try { - program.getBuildInfo(*m_device, CL_PROGRAM_BUILD_LOG, &build_log); + program.getBuildInfo(m_device->getUnderlyingDevice(), CL_PROGRAM_BUILD_LOG, &build_log); } catch (cl::Error& err) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 800d788..b8b6eef 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -41,6 +41,9 @@ target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE ) endif() +target_compile_definitions(${TESTS_PROJECT_NAME} PRIVATE + ENABLE_TESTING=1 +) target_link_libraries (${TESTS_PROJECT_NAME} ${libs}) target_include_directories(${TESTS_PROJECT_NAME} PRIVATE "${PROJECT_SOURCE_DIR}/include" diff --git a/tests/diagnostics-tests.cpp b/tests/diagnostics-tests.cpp index 57bf4da..14f7b1a 100644 --- a/tests/diagnostics-tests.cpp +++ b/tests/diagnostics-tests.cpp @@ -16,66 +16,68 @@ using namespace ocls; using namespace testing; +namespace { -TEST(DiagnosticsTest, SelectDeviceBasedOnPowerIndexDuringTheCreation) +const uint32_t deviceID1 = 12345678; +const uint32_t deviceID2 = 23456789; + +std::vector GetTestDevices() { - auto mockCLInfo = std::make_shared(); - std::vector devices = {cl::Device(), cl::Device()}; + return {Device(deviceID1, "Test Device 1", 10), Device(deviceID2, "Test Device 2", 20)}; +} + +class DiagnosticsTest : public ::testing::Test +{ +protected: + std::shared_ptr mockCLInfo; + + void SetUp() override + { + mockCLInfo = std::make_shared(); + } +}; + +} // namespace - EXPECT_CALL(*mockCLInfo, GetDevices()).WillOnce(Return(devices)); - EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)).WillOnce(Return(10)).WillOnce(Return(10)); - EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); +TEST_F(DiagnosticsTest, SelectDeviceBasedOnPowerIndexDuringTheCreation) +{ + EXPECT_CALL(*mockCLInfo, GetDevices()).WillOnce(Return(GetTestDevices())); + + auto diagnostics = CreateDiagnostics(mockCLInfo); - CreateDiagnostics(mockCLInfo); + ASSERT_TRUE(diagnostics->GetDevice().has_value()); + EXPECT_EQ(diagnostics->GetDevice().value().GetID(), deviceID2); } -TEST(DiagnosticsTest, SelectDeviceBasedOnPowerIndex) +TEST_F(DiagnosticsTest, SelectDeviceBasedOnPowerIndex) { - auto mockCLInfo = std::make_shared(); - std::vector devices = {cl::Device(), cl::Device()}; - - EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); - EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)) - .WillOnce(Return(10)) - .WillOnce(Return(20)) - .WillOnce(Return(10)) - .WillOnce(Return(20)) - .WillOnce(Return(20)) - .WillOnce(Return(20)); - EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(GetTestDevices())); + auto diagnostics = CreateDiagnostics(mockCLInfo); diagnostics->SetOpenCLDevice(0); + + ASSERT_TRUE(diagnostics->GetDevice().has_value()); + EXPECT_EQ(diagnostics->GetDevice().value().GetID(), deviceID2); } -TEST(DiagnosticsTest, SelectDeviceBasedOnExistingIndex) +TEST_F(DiagnosticsTest, SelectDeviceBasedOnExistingIndex) { - auto mockCLInfo = std::make_shared(); - std::vector devices = {cl::Device(), cl::Device()}; + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(GetTestDevices())); - EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); - EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)).WillOnce(Return(10)).WillOnce(Return(20)); - EXPECT_CALL(*mockCLInfo, GetDeviceID(_)).WillOnce(Return(3138399603)).WillOnce(Return(2027288592)); - EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).Times(2).WillRepeatedly(Return("")); auto diagnostics = CreateDiagnostics(mockCLInfo); - diagnostics->SetOpenCLDevice(2027288592); + diagnostics->SetOpenCLDevice(12345678); + + ASSERT_TRUE(diagnostics->GetDevice().has_value()); + EXPECT_EQ(diagnostics->GetDevice().value().GetID(), deviceID1); } -TEST(DiagnosticsTest, SelectDeviceBasedOnNonExistingIndex) +TEST_F(DiagnosticsTest, SelectDeviceBasedOnNonExistingIndex) { - auto mockCLInfo = std::make_shared(); - std::vector devices = {cl::Device(), cl::Device()}; - - EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(devices)); - EXPECT_CALL(*mockCLInfo, GetDeviceID(_)).WillOnce(Return(3138399603)).WillOnce(Return(2027288592)); - EXPECT_CALL(*mockCLInfo, GetDevicePowerIndex(_)) - .WillOnce(Return(10)) - .WillOnce(Return(20)) - .WillOnce(Return(10)) - .WillOnce(Return(20)) - .WillOnce(Return(20)) - .WillOnce(Return(20)); - EXPECT_CALL(*mockCLInfo, GetDeviceDescription(_)).WillOnce(Return("")); + EXPECT_CALL(*mockCLInfo, GetDevices()).Times(2).WillRepeatedly(Return(GetTestDevices())); auto diagnostics = CreateDiagnostics(mockCLInfo); - diagnostics->SetOpenCLDevice(static_cast(4527288514)); + diagnostics->SetOpenCLDevice(10000000); + + ASSERT_TRUE(diagnostics->GetDevice().has_value()); + EXPECT_EQ(diagnostics->GetDevice().value().GetID(), deviceID2); } diff --git a/tests/mocks/clinfo-mock.hpp b/tests/mocks/clinfo-mock.hpp index 07a0535..4c6bb44 100644 --- a/tests/mocks/clinfo-mock.hpp +++ b/tests/mocks/clinfo-mock.hpp @@ -14,11 +14,5 @@ class CLInfoMock : public ocls::ICLInfo public: MOCK_METHOD(nlohmann::json, json, (), (override)); - MOCK_METHOD(std::vector, GetDevices, (), (override)); - - MOCK_METHOD(uint32_t, GetDeviceID, (const cl::Device&), (override)); - - MOCK_METHOD(std::string, GetDeviceDescription, (const cl::Device&), (override)); - - MOCK_METHOD(size_t, GetDevicePowerIndex, (const cl::Device&), (override)); + MOCK_METHOD(std::vector, GetDevices, (), (override)); }; diff --git a/tests/mocks/diagnostics-mock.hpp b/tests/mocks/diagnostics-mock.hpp index 0702f2e..7ca73fd 100644 --- a/tests/mocks/diagnostics-mock.hpp +++ b/tests/mocks/diagnostics-mock.hpp @@ -20,7 +20,9 @@ class DiagnosticsMock : public ocls::IDiagnostics MOCK_METHOD(void, SetOpenCLDevice, (uint32_t), (override)); + MOCK_METHOD(std::optional, GetDevice, (), (const, override)); + MOCK_METHOD(std::string, GetBuildLog, (const ocls::Source&), (override)); - + MOCK_METHOD(nlohmann::json, GetDiagnostics, (const ocls::Source&), (override)); }; From c43e922118799a7291b0cb0a09a890b530cc59ec Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 25 Aug 2023 19:19:47 +0300 Subject: [PATCH 30/32] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4ed1873..be0ec38 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ test_package/* .build* ## VS Code +.vs .vscode CMakePresets.json From 8e7d85107b53ec2003f10f34482c4dd154198bfa Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 25 Aug 2023 20:02:29 +0300 Subject: [PATCH 31/32] Return a valid JSON when no OpenCL platforms found --- src/clinfo.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/clinfo.cpp b/src/clinfo.cpp index 5510825..a5dac10 100644 --- a/src/clinfo.cpp +++ b/src/clinfo.cpp @@ -526,19 +526,16 @@ class CLInfo final : public ICLInfo public: nlohmann::json json() { - const auto platforms = GetPlatforms(); - if (platforms.size() == 0) - { - return {}; - } - std::vector jsonPlatforms; - for (const auto& platform : platforms) + const auto platforms = GetPlatforms(); + if (platforms.size() > 0) { - logger()->trace("{}", GetPlatformDescription(platform)); - jsonPlatforms.emplace_back(GetPlatformJSONInfo(platform)); + for (const auto& platform : platforms) + { + logger()->trace("{}", GetPlatformDescription(platform)); + jsonPlatforms.emplace_back(GetPlatformJSONInfo(platform)); + } } - return nlohmann::json {{"PLATFORMS", jsonPlatforms}}; } From 5878f4ebb38425ab9d2f5cbe2fbe4d8f63cbf719 Mon Sep 17 00:00:00 2001 From: Galarius Date: Fri, 25 Aug 2023 19:20:07 +0300 Subject: [PATCH 32/32] Bump version to 0.6.0 --- version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version b/version index 2411653..09a3acf 100644 --- a/version +++ b/version @@ -1 +1 @@ -0.5.2 \ No newline at end of file +0.6.0 \ No newline at end of file