Skip to content

Commit

Permalink
feat: support manually dispose
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed May 29, 2024
1 parent 83f0e3f commit 8a539d6
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 2 deletions.
2 changes: 2 additions & 0 deletions lib/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ interface TensorImpl {
setIndex(position: Array<number>, data: number|boolean): void
slice(slice_position: Array<Optional<Optional<number>[]>>): TensorImpl;
reshape(shape: number[]): TensorImpl;
dispose(): void;
}

interface Tensor {
Expand All @@ -41,6 +42,7 @@ interface ModuleImpl {
forward(inputs: EValue[]): Promise<EValue[]>;
execute(method_name: string, inputs: EValue[]): Promise<EValue[]>;
get method_names(): string[];
dispose(): void;
}

interface Module {
Expand Down
27 changes: 26 additions & 1 deletion src/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,11 @@ Napi::Value Module::Execute(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (!module_) {
Napi::TypeError::New(env, "Module is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

if (info.Length() < 2) {
Napi::TypeError::New(env, "Expected method name and input array")
.ThrowAsJavaScriptException();
Expand Down Expand Up @@ -180,6 +185,11 @@ Napi::Value Module::LoadMethod(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (!module_) {
Napi::TypeError::New(env, "Module is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

if (info.Length() < 1 || !info[0].IsString()) {
Napi::TypeError::New(env, "Expected a string").ThrowAsJavaScriptException();
return env.Undefined();
Expand All @@ -200,6 +210,11 @@ Napi::Value Module::Forward(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (!module_) {
Napi::TypeError::New(env, "Module is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

if (info.Length() < 1 || !info[0].IsArray()) {
Napi::TypeError::New(env, "Expected input array")
.ThrowAsJavaScriptException();
Expand All @@ -221,6 +236,11 @@ Napi::Value Module::MethodNames(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (!module_) {
Napi::TypeError::New(env, "Module is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

auto result = (*module_)->method_names();
if (result.ok()) {
auto names = result.get();
Expand All @@ -238,14 +258,19 @@ Napi::Value Module::MethodNames(const Napi::CallbackInfo &info) {
}
}

void Module::Dispose(const Napi::CallbackInfo &info) {
module_.reset();
}

Napi::Object Module::Init(Napi::Env env, Napi::Object exports) {
Napi::Function func = DefineClass(
env, "Module",
{StaticMethod("load", &Module::Load),
InstanceAccessor("method_names", &Module::MethodNames, nullptr),
InstanceMethod("loadMethod", &Module::LoadMethod),
InstanceMethod("forward", &Module::Forward),
InstanceMethod("execute", &Module::Execute)});
InstanceMethod("execute", &Module::Execute),
InstanceMethod("dispose", &Module::Dispose)});

constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
Expand Down
1 change: 1 addition & 0 deletions src/Module.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Module : public Napi::ObjectWrap<Module> {
Napi::Value Forward(const Napi::CallbackInfo &info);
Napi::Value Execute(const Napi::CallbackInfo &info);
Napi::Value MethodNames(const Napi::CallbackInfo &info);
void Dispose(const Napi::CallbackInfo &info);

private:
static Napi::FunctionReference constructor;
Expand Down
36 changes: 35 additions & 1 deletion src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ Napi::Value Tensor::Shape(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (tensor_ == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

Napi::Array shape = Napi::Array::New(env, tensor_->dim());
for (size_t i = 0; i < tensor_->dim(); i++) {
shape.Set(i, Napi::Number::New(env, tensor_->size(i)));
Expand All @@ -126,6 +131,11 @@ Napi::Value Tensor::GetData(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (tensor_ == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

size_t size = tensor_->nbytes();
size_t n_elem = tensor_->numel();
const void *data = tensor_->const_data_ptr();
Expand Down Expand Up @@ -191,6 +201,11 @@ void Tensor::SetIndex(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (tensor_ == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return;
}

if (info.Length() < 2) {
Napi::TypeError::New(env, "Expected 2 arguments")
.ThrowAsJavaScriptException();
Expand Down Expand Up @@ -262,6 +277,11 @@ Napi::Value Tensor::Slice(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (tensor_ == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

if (info.Length() < 1 || !info[0].IsArray()) {
Napi::TypeError::New(env, "Expected array").ThrowAsJavaScriptException();
return env.Undefined();
Expand Down Expand Up @@ -364,6 +384,10 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) {
}
auto tensor = Napi::ObjectWrap<Tensor>::Unwrap(item.As<Napi::Object>())->
GetTensorPtr();
if (tensor == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}
tensors[i] = tensor;
if (i == 0) {
dtype = tensor->scalar_type();
Expand Down Expand Up @@ -439,6 +463,11 @@ Napi::Value Tensor::Reshape(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Napi::HandleScope scope(env);

if (tensor_ == nullptr) {
Napi::TypeError::New(env, "Tensor is disposed").ThrowAsJavaScriptException();
return env.Undefined();
}

if (info.Length() < 1 || !info[0].IsArray()) {
Napi::TypeError::New(env, "Expected array").ThrowAsJavaScriptException();
return env.Undefined();
Expand Down Expand Up @@ -471,6 +500,10 @@ Napi::Value Tensor::Reshape(const Napi::CallbackInfo &info) {
return Tensor::New(tensor);
}

void Tensor::Dispose(const Napi::CallbackInfo &info) {
tensor_.reset();
}

Napi::Object Tensor::Init(Napi::Env env, Napi::Object exports) {
Napi::Function func =
DefineClass(env, "Tensor",
Expand All @@ -480,7 +513,8 @@ Napi::Object Tensor::Init(Napi::Env env, Napi::Object exports) {
InstanceAccessor("data", &Tensor::GetData, &Tensor::SetData),
InstanceMethod("setIndex", &Tensor::SetIndex),
InstanceMethod("slice", &Tensor::Slice),
InstanceMethod("reshape", &Tensor::Reshape)});
InstanceMethod("reshape", &Tensor::Reshape),
InstanceMethod("dispose", &Tensor::Dispose)});

constructor = Napi::Persistent(func);
constructor.SuppressDestruct();
Expand Down
1 change: 1 addition & 0 deletions src/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Tensor : public Napi::ObjectWrap<Tensor> {
void SetIndex(const Napi::CallbackInfo &info);
Napi::Value Slice(const Napi::CallbackInfo &info);
Napi::Value Reshape(const Napi::CallbackInfo &info);
void Dispose(const Napi::CallbackInfo &info);

static Napi::Value Concat(const Napi::CallbackInfo &info);

Expand Down

0 comments on commit 8a539d6

Please sign in to comment.