Skip to content

Commit

Permalink
fix: fix random test fail
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed May 29, 2024
1 parent f349681 commit 2135056
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
46 changes: 25 additions & 21 deletions src/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ Napi::Value Tensor::Slice(const Napi::CallbackInfo &info) {
}

ssize_t elem_size = tensor_->element_size();
void *newData = malloc(n_elem * elem_size);
char *newData = new char[n_elem * elem_size];

const void *data = tensor_->const_data_ptr();
const char *data = (char*) tensor_->const_data_ptr();

for (size_t i = 0; i < n_elem; i++) {
size_t offset = 0;
Expand All @@ -313,8 +313,8 @@ Napi::Value Tensor::Slice(const Napi::CallbackInfo &info) {
pos /= dim_size;
offset += (startVec[j] + dim_pos) * stride;
}
memcpy(reinterpret_cast<uint8_t *>(newData) + i * elem_size,
reinterpret_cast<const uint8_t *>(data) + offset * elem_size,
memcpy(newData + i * elem_size,
data + offset * elem_size,
elem_size);
}

Expand Down Expand Up @@ -349,7 +349,7 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) {
}

size_t axis = info.Length() > 1 ? info[1].ToNumber().Int32Value() : 0;
std::vector<Tensor *> tensors(n_tensors);
std::vector<exec_aten::Tensor *> tensors(n_tensors);
std::vector<size_t> sizes;
size_t rank = 0;
exec_aten::ScalarType dtype;
Expand All @@ -362,34 +362,35 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) {
.ThrowAsJavaScriptException();
return env.Undefined();
}
auto tensor = Napi::ObjectWrap<Tensor>::Unwrap(item.As<Napi::Object>());
auto tensor = Napi::ObjectWrap<Tensor>::Unwrap(item.As<Napi::Object>())->
GetTensorPtr();
tensors[i] = tensor;
if (i == 0) {
dtype = tensor->tensor_->scalar_type();
rank = tensor->tensor_->dim();
dtype = tensor->scalar_type();
rank = tensor->dim();
sizes.resize(rank);
for (size_t j = 0; j < rank; j++) {
sizes[j] = tensor->tensor_->size(j);
sizes[j] = tensor->size(j);
}
if (axis >= rank) {
Napi::TypeError::New(env, "Invalid axis").ThrowAsJavaScriptException();
return env.Undefined();
}
} else if (dtype != tensor->tensor_->scalar_type()) {
} else if (dtype != tensor->scalar_type()) {
Napi::TypeError::New(env, "Tensors have different dtypes")
.ThrowAsJavaScriptException();
return env.Undefined();
} else if (rank != tensor->tensor_->dim()) {
} else if (rank != tensor->dim()) {
Napi::TypeError::New(env, "Tensors have different ranks")
.ThrowAsJavaScriptException();
return env.Undefined();
} else {
for (size_t j = 0; j < rank; j++) {
if (j == axis) {
sizes[j] += tensor->tensor_->size(j);
sizes[j] += tensor->size(j);
continue;
}
if (sizes[j] != tensor->tensor_->size(j) && j != axis) {
if (sizes[j] != tensor->size(j) && j != axis) {
Napi::TypeError::New(env, "Tensors have different sizes")
.ThrowAsJavaScriptException();
return env.Undefined();
Expand All @@ -402,25 +403,25 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) {
for (size_t i = 0; i < rank; i++) {
n_elem *= sizes[i];
}
ssize_t elem_size = tensors[0]->tensor_->element_size();
void *newData = malloc(n_elem * elem_size);
ssize_t elem_size = tensors[0]->element_size();
char *newData = new char[n_elem * elem_size];

size_t trip_step = 1;
for (size_t j = 0; j < axis; j++) {
trip_step *= tensors[0]->tensor_->size(j);
trip_step *= tensors[0]->size(j);
}

size_t chunk_size = elem_size;
for (size_t k = axis; k < rank; k++) {
chunk_size *= tensors[0]->tensor_->size(k);
chunk_size *= tensors[0]->size(k);
}

for (size_t i = 0; i < trip_step; i++) {
for (size_t j = 0; j < n_tensors; j++) {
const void *data = tensors[j]->tensor_->const_data_ptr();
memcpy(reinterpret_cast<uint8_t *>(newData) + j * chunk_size +
const char *data = (char*) tensors[j]->const_data_ptr();
memcpy(newData + j * chunk_size +
i * n_tensors * chunk_size,
reinterpret_cast<const uint8_t *>(data) + chunk_size * i,
data + chunk_size * i,
chunk_size);
}
}
Expand Down Expand Up @@ -462,8 +463,11 @@ Napi::Value Tensor::Reshape(const Napi::CallbackInfo &info) {
return env.Undefined();
}

char *data = new char[tensor_->nbytes()];
memcpy(data, tensor_->const_data_ptr(), tensor_->nbytes());

exec_aten::Tensor tensor(new exec_aten::TensorImpl(
tensor_->scalar_type(), rank, dims, tensor_->mutable_data_ptr()));
tensor_->scalar_type(), rank, dims, data));
return Tensor::New(tensor);
}

Expand Down
1 change: 1 addition & 0 deletions src/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Tensor : public Napi::ObjectWrap<Tensor> {
}

inline exec_aten::Tensor GetTensor() { return *tensor_; }
inline exec_aten::Tensor* GetTensorPtr() { return tensor_.get(); }

protected:
Napi::Value Shape(const Napi::CallbackInfo &info);
Expand Down

0 comments on commit 2135056

Please sign in to comment.