Skip to content

Commit

Permalink
Skrypt: WaitForModule if IsModuleLoading to get MappedModuleVariable
Browse files Browse the repository at this point in the history
  • Loading branch information
ohhmm committed Mar 20, 2024
1 parent c690111 commit 9894b46
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
48 changes: 23 additions & 25 deletions libskrypt/skrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ const ::omnn::math::Variable& Skrypt::MappedModuleVariable(const ::omnn::math::V
);
if (!module) {
module = Module(moduleName);
} else if (IsModuleLoading(moduleName)) {
if(module != WaitTillModuleLoadingComplete(moduleName))
LOG_AND_IMPLEMENT("Module provided does not match");
}
auto& moduleVarNames = module->GetVarNames();
auto moduleVarIt = moduleVarNames.find(name);
Expand Down Expand Up @@ -411,6 +414,12 @@ bool Skrypt::IsModuleLoading(std::string_view name) const {
return modulesLoading.contains(name);
}

void Skrypt::WaitAllModulesLoadingComplete() {
while (!modulesLoadingQueue.empty()) {
modulesLoadingQueue.PeekNextResult();
}
}

Skrypt::module_t Skrypt::Module(std::string_view name) {
auto module = GetLoadedModule(name);
if (!module) {
Expand Down Expand Up @@ -449,31 +458,21 @@ Skrypt::module_t Skrypt::Module(std::string_view name) {
module->Load(FindModulePath(name));

std::unique_lock lock(modulesLoadingMutex);
auto loading = modulesLoading.find(std::string(name));
auto loading = modulesLoading.find(name);
if (loading != modulesLoading.end()) {
modulesLoading.erase(loading);
} else {
std::cerr << "Module " << name << " was not found in the loading map" << std::endl; // race condition
}
} else {
// wait for the module to be loaded
bool waiting = true;
do {
auto loaded = modulesLoadingQueue.PeekNextResult();
auto it = loaded.find(name);
if (it != loaded.end()) {
if (module != it->second.get()) {
IMPLEMENT
}
waiting = false;
}
} while (waiting);
WaitTillModuleLoadingComplete(name);
}
}
return module;
}

Skrypt::loading_module_t Skrypt::StartLoadingModule(std::string_view name) {
std::cout << "Module " << name << " loading started" << std::endl;
return std::async(
std::launch::async, [this, name]() {
auto module = Module(name);
Expand All @@ -489,20 +488,16 @@ Skrypt::loading_modules_t Skrypt::LoadModules(const ::omnn::math::Valuable& v) {
if (dot != std::string::npos) {
auto moduleName = name.substr(0, dot);
auto loading = loadingModules.find(std::string(moduleName));
if (loading == loadingModules.end()) {
// not loading yet
if (loading == loadingModules.end()) { // not loading here yet
auto loaded = GetLoadedModule(moduleName);
if (loaded)
{
std::shared_lock lock(modulesMapMutex);
auto loaded = modules.find(moduleName);
if (loaded != modules.end()) {
// already loaded
std::promise<module_t> promise;
promise.set_value(loaded->second);
loadingModules.emplace(moduleName, promise.get_future());
continue;
}
std::promise<module_t> promise;
promise.set_value(loaded);
loadingModules.emplace(moduleName, promise.get_future());
} else {
loadingModules.emplace(moduleName, StartLoadingModule(moduleName));
}
loadingModules.emplace(moduleName, StartLoadingModule(moduleName));
}
}
}
Expand Down Expand Up @@ -603,6 +598,9 @@ const ::omnn::math::Valuable::solutions_t& Skrypt::Known(const ::omnn::math::Var
}
}
}
} else {
WaitAllModulesLoadingComplete();
known = std::cref(base::Known(v));
}
}
}
Expand Down
11 changes: 10 additions & 1 deletion libskrypt/skrypt.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,20 @@ class Skrypt
moduleFileSearchAdditionalPaths.emplace_back(std::forward<T>(p));
}
bool IsModuleLoading(std::string_view name) const;

loading_module_t StartLoadingModule(std::string_view name);
loading_modules_t LoadModules(const ::omnn::math::Valuable& v);
loading_modules_future_t StartLoadingModules(const ::omnn::math::Valuable& v);
module_t WaitTillModuleLoadingComplete(std::string_view name);
void BackgroudLoadingModules(const ::omnn::math::Valuable& v);

/// <summary>
/// Wait for the module to be loaded
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
module_t WaitTillModuleLoadingComplete(std::string_view name);
void WaitAllModulesLoadingComplete();

std::string_view GetVariableName(const ::omnn::math::Variable&) const;
std::string_view GetModuleName(std::string_view variableName) const;
std::string_view GetModuleName(const ::omnn::math::Variable&) const;
Expand Down

0 comments on commit 9894b46

Please sign in to comment.