Skip to content

Commit

Permalink
Refactor crypto for light runtime (#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkornaukhov03 authored Sep 17, 2024
1 parent 3ff5b53 commit 3b0a9d6
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 100 deletions.
112 changes: 15 additions & 97 deletions runtime-light/stdlib/crypto/crypto-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,31 @@
#include "common/tl/constants/common.h"
#include "runtime-light/stdlib/component/component-api.h"
#include "runtime-light/tl/tl-core.h"

namespace {

// Crypto-Specific TL magics

constexpr uint32_t TL_CERT_INFO_ITEM_LONG = 0x533f'f89f;
constexpr uint32_t TL_CERT_INFO_ITEM_STR = 0xc427'feef;
constexpr uint32_t TL_CERT_INFO_ITEM_DICT = 0x1ea8'a774;

constexpr uint32_t TL_GET_PEM_CERT_INFO = 0xa50c'fd6c;
constexpr uint32_t TL_GET_CRYPTOSECURE_PSEUDORANDOM_BYTES = 0x2491'b81d;
} // namespace
#include "runtime-light/tl/tl-functions.h"
#include "runtime-light/tl/tl-types.h"

task_t<Optional<string>> f$openssl_random_pseudo_bytes(int64_t length) noexcept {
if (length <= 0 || length > string::max_size()) {
co_return false;
}

tl::TLBuffer buffer;
buffer.store_trivial<uint32_t>(TL_GET_CRYPTOSECURE_PSEUDORANDOM_BYTES);
buffer.store_trivial<int32_t>(length);

// TODO think about performance when transferring data

tl::GetCryptosecurePseudorandomBytes request{.size = static_cast<int32_t>(length)};

tl::TLBuffer buffer;
request.store(buffer);

string request_buf;
request_buf.append(buffer.data(), buffer.size());

auto query = f$component_client_send_request(string("crypto"), request_buf);
auto query = f$component_client_send_request(string("crypto"), string{buffer.data(), static_cast<string::size_type>(buffer.size())});
string resp = co_await f$component_client_fetch_response(co_await query);

buffer.clean();
buffer.store_bytes(resp.c_str(), resp.size());

// Maybe better to do this in some structure, but there's not much work to do with TL here
std::optional<uint32_t> magic = buffer.fetch_trivial<uint32_t>();
if (!magic.has_value() || *magic != TL_MAYBE_TRUE) {
co_return false;
Expand All @@ -49,10 +41,10 @@ task_t<Optional<string>> f$openssl_random_pseudo_bytes(int64_t length) noexcept
}

task_t<Optional<array<mixed>>> f$openssl_x509_parse(const string &data, bool shortnames) noexcept {
tl::GetPemCertInfo request{.is_short = shortnames, .bytes = data};

tl::TLBuffer buffer;
buffer.store_trivial<uint32_t>(TL_GET_PEM_CERT_INFO);
buffer.store_trivial<uint32_t>(shortnames ? TL_BOOL_TRUE : TL_BOOL_FALSE);
buffer.store_string(std::string_view(data.c_str(), data.size()));
request.store(buffer);

string request_buf;
request_buf.append(buffer.data(), buffer.size());
Expand All @@ -63,84 +55,10 @@ task_t<Optional<array<mixed>>> f$openssl_x509_parse(const string &data, bool sho
buffer.clean();
buffer.store_bytes(resp_from_platform.c_str(), resp_from_platform.size());

if (const auto magic = buffer.fetch_trivial<uint32_t>(); magic.value_or(TL_ZERO) != TL_MAYBE_TRUE) {
tl::GetPemCertInfoResponse response;
if (!response.fetch(buffer)) {
co_return false;
}

if (const auto magic = buffer.fetch_trivial<uint32_t>(); magic.value_or(TL_ZERO) != TL_DICTIONARY) {
co_return false;
}

const std::optional<uint32_t> size = buffer.fetch_trivial<uint32_t>();
if (!size.has_value()) {
co_return false;
}

auto response = array<mixed>::create();

for (uint32_t i = 0; i < size; ++i) {
const auto key_view = buffer.fetch_string();
if (key_view.empty()) {
co_return false;
}

const auto key = string(key_view.data(), key_view.length());

const std::optional<uint32_t> magic = buffer.fetch_trivial<uint32_t>();
if (!magic.has_value()) {
co_return false;
}

switch (*magic) {
case TL_CERT_INFO_ITEM_LONG: {
const std::optional<int64_t> val = buffer.fetch_trivial<int64_t>();
if (!val.has_value()) {
co_return false;
}
response[key] = *val;
break;
}
case TL_CERT_INFO_ITEM_STR: {
const auto value_view = buffer.fetch_string();
if (value_view.empty()) {
co_return false;
}
const auto value = string(value_view.data(), value_view.size());

response[key] = string(value_view.data(), value_view.size());
break;
}
case TL_CERT_INFO_ITEM_DICT: {
auto sub_array = array<string>::create();
const std::optional<uint32_t> sub_size = buffer.fetch_trivial<uint32_t>();

if (!sub_size.has_value()) {
co_return false;
}

for (size_t j = 0; j < sub_size; ++j) {
const auto sub_key_view = buffer.fetch_string();
if (sub_key_view.empty()) {
co_return false;
}
const auto sub_key = string(sub_key_view.data(), sub_key_view.size());

const auto sub_value_view = buffer.fetch_string();
if (sub_value_view.empty()) {
co_return false;
}
const auto sub_value = string(sub_value_view.data(), sub_value_view.size());

sub_array[sub_key] = sub_value;
}
response[key] = sub_array;

break;
}
default:
co_return false;
}
}

co_return response;
co_return response.data;
}
12 changes: 12 additions & 0 deletions runtime-light/tl/tl-functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cstdint>
#include <string_view>

#include "common/tl/constants/common.h"
#include "runtime-light/tl/tl-core.h"

namespace {
Expand Down Expand Up @@ -45,4 +46,15 @@ void K2InvokeJobWorker::store(TLBuffer &tlb) const noexcept {
tlb.store_string({body.c_str(), body.size()});
}

void GetCryptosecurePseudorandomBytes::store(TLBuffer &tlb) const noexcept {
tlb.store_trivial<uint32_t>(GET_CRYPTOSECURE_PSEUDORANDOM_BYTES_MAGIC);
tlb.store_trivial<int32_t>(size);
}

void GetPemCertInfo::store(TLBuffer &tlb) const noexcept {
tlb.store_trivial<uint32_t>(GET_PEM_CERT_INFO_MAGIC);
tlb.store_trivial<uint32_t>(is_short ? TL_BOOL_TRUE : TL_BOOL_FALSE);
tlb.store_string(std::string_view{bytes.c_str(), bytes.size()});
}

} // namespace tl
25 changes: 23 additions & 2 deletions runtime-light/tl/tl-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

namespace tl {

inline constexpr uint32_t K2_INVOKE_HTTP_MAGIC = 0xd909efe8;
inline constexpr uint32_t K2_INVOKE_JOB_WORKER_MAGIC = 0x437d7312;
// ===== JOB WORKERS =====

inline constexpr uint32_t K2_INVOKE_HTTP_MAGIC = 0xd909'efe8;
inline constexpr uint32_t K2_INVOKE_JOB_WORKER_MAGIC = 0x437d'7312;


struct K2InvokeJobWorker final {
uint64_t image_id{};
Expand All @@ -26,4 +29,22 @@ struct K2InvokeJobWorker final {
void store(TLBuffer &tlb) const noexcept;
};

// ===== CRYPTO =====

inline constexpr uint32_t GET_CRYPTOSECURE_PSEUDORANDOM_BYTES_MAGIC = 0x2491'b81d;
inline constexpr uint32_t GET_PEM_CERT_INFO_MAGIC = 0xa50c'fd6c;

struct GetCryptosecurePseudorandomBytes final {
int32_t size{};

void store(TLBuffer &tlb) const noexcept;
};

struct GetPemCertInfo final {
bool is_short{true};
string bytes;

void store(TLBuffer &tlb) const noexcept;
};

} // namespace tl
92 changes: 92 additions & 0 deletions runtime-light/tl/tl-types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@
#include <string_view>
#include <tuple>

#include "common/tl/constants/common.h"
#include "runtime-light/tl/tl-core.h"

namespace {

// magic + flags + job_id + minimum string size length
constexpr auto K2_JOB_WORKER_RESPONSE_MIN_SIZE = sizeof(uint32_t) + sizeof(uint32_t) + sizeof(int64_t) + tl::SMALL_STRING_SIZE_LEN;

enum CertInfoItem : uint32_t {
LONG_MAGIC = 0x533f'f89f,
STR_MAGIC = 0xc427'feef,
DICT_MAGIC = 0x1ea8'a774
};

} // namespace

namespace tl {
Expand All @@ -38,4 +45,89 @@ void K2JobWorkerResponse::store(TLBuffer &tlb) const noexcept {
tlb.store_string({body.c_str(), body.size()});
}

bool GetPemCertInfoResponse::fetch(TLBuffer &tlb) noexcept {
if (const auto magic = tlb.fetch_trivial<uint32_t>(); magic.value_or(TL_ZERO) != TL_MAYBE_TRUE) {
return false;
}

if (const auto magic = tlb.fetch_trivial<uint32_t>(); magic.value_or(TL_ZERO) != TL_DICTIONARY) {
return false;
}

const std::optional<uint32_t> size = tlb.fetch_trivial<uint32_t>();
if (!size.has_value()) {
return false;
}

auto response = array<mixed>::create();
response.reserve(*size, false);

for (uint32_t i = 0; i < *size; ++i) {
const auto key_view = tlb.fetch_string();
if (key_view.empty()) {
return false;
}

const auto key = string(key_view.data(), key_view.length());

const std::optional<uint32_t> magic = tlb.fetch_trivial<uint32_t>();
if (!magic.has_value()) {
return false;
}

switch (*magic) {
case CertInfoItem::LONG_MAGIC: {
const std::optional<int64_t> val = tlb.fetch_trivial<int64_t>();
if (!val.has_value()) {
return false;
}
response[key] = *val;
break;
}
case CertInfoItem::STR_MAGIC: {
const auto value_view = tlb.fetch_string();
if (value_view.empty()) {
return false;
}
const auto value = string(value_view.data(), value_view.size());

response[key] = string(value_view.data(), value_view.size());
break;
}
case CertInfoItem::DICT_MAGIC: {
auto sub_array = array<string>::create();
const std::optional<uint32_t> sub_size = tlb.fetch_trivial<uint32_t>();

if (!sub_size.has_value()) {
return false;
}

for (size_t j = 0; j < sub_size; ++j) {
const auto sub_key_view = tlb.fetch_string();
if (sub_key_view.empty()) {
return false;
}
const auto sub_key = string(sub_key_view.data(), sub_key_view.size());

const auto sub_value_view = tlb.fetch_string();
if (sub_value_view.empty()) {
return false;
}
const auto sub_value = string(sub_value_view.data(), sub_value_view.size());

sub_array[sub_key] = sub_value;
}
response[key] = sub_array;

break;
}
default:
return false;
}
}

data = std::move(response);
return true;
}

} // namespace tl
14 changes: 13 additions & 1 deletion runtime-light/tl/tl-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@

namespace tl {

inline constexpr uint32_t K2_JOB_WORKER_RESPONSE_MAGIC = 0x3afb3a08;
// ===== JOB WORKERS =====

inline constexpr uint32_t K2_JOB_WORKER_RESPONSE_MAGIC = 0x3afb'3a08;

struct K2JobWorkerResponse final {
int64_t job_id{};
Expand All @@ -22,4 +24,14 @@ struct K2JobWorkerResponse final {
void store(TLBuffer &tlb) const noexcept;
};

// ===== CRYPTO =====

// Actually it's "Maybe (Dictionary CertInfoItem)"
// But I now want to have this logic separately
struct GetPemCertInfoResponse {
array<mixed> data;

bool fetch(TLBuffer &tlb) noexcept;
};

} // namespace tl

0 comments on commit 3b0a9d6

Please sign in to comment.