diff --git a/src/Tensor.cpp b/src/Tensor.cpp index e5e4b2e..465022a 100644 --- a/src/Tensor.cpp +++ b/src/Tensor.cpp @@ -299,9 +299,9 @@ Napi::Value Tensor::Slice(const Napi::CallbackInfo &info) { } ssize_t elem_size = tensor_->element_size(); - void *newData = malloc(n_elem * elem_size); + char *newData = new char[n_elem * elem_size]; - const void *data = tensor_->const_data_ptr(); + const char *data = (char*) tensor_->const_data_ptr(); for (size_t i = 0; i < n_elem; i++) { size_t offset = 0; @@ -313,8 +313,8 @@ Napi::Value Tensor::Slice(const Napi::CallbackInfo &info) { pos /= dim_size; offset += (startVec[j] + dim_pos) * stride; } - memcpy(reinterpret_cast(newData) + i * elem_size, - reinterpret_cast(data) + offset * elem_size, + memcpy(newData + i * elem_size, + data + offset * elem_size, elem_size); } @@ -349,7 +349,7 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) { } size_t axis = info.Length() > 1 ? info[1].ToNumber().Int32Value() : 0; - std::vector tensors(n_tensors); + std::vector tensors(n_tensors); std::vector sizes; size_t rank = 0; exec_aten::ScalarType dtype; @@ -362,34 +362,35 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) { .ThrowAsJavaScriptException(); return env.Undefined(); } - auto tensor = Napi::ObjectWrap::Unwrap(item.As()); + auto tensor = Napi::ObjectWrap::Unwrap(item.As())-> + GetTensorPtr(); tensors[i] = tensor; if (i == 0) { - dtype = tensor->tensor_->scalar_type(); - rank = tensor->tensor_->dim(); + dtype = tensor->scalar_type(); + rank = tensor->dim(); sizes.resize(rank); for (size_t j = 0; j < rank; j++) { - sizes[j] = tensor->tensor_->size(j); + sizes[j] = tensor->size(j); } if (axis >= rank) { Napi::TypeError::New(env, "Invalid axis").ThrowAsJavaScriptException(); return env.Undefined(); } - } else if (dtype != tensor->tensor_->scalar_type()) { + } else if (dtype != tensor->scalar_type()) { Napi::TypeError::New(env, "Tensors have different dtypes") .ThrowAsJavaScriptException(); return env.Undefined(); - } else if (rank != tensor->tensor_->dim()) { + } else if (rank != tensor->dim()) { Napi::TypeError::New(env, "Tensors have different ranks") .ThrowAsJavaScriptException(); return env.Undefined(); } else { for (size_t j = 0; j < rank; j++) { if (j == axis) { - sizes[j] += tensor->tensor_->size(j); + sizes[j] += tensor->size(j); continue; } - if (sizes[j] != tensor->tensor_->size(j) && j != axis) { + if (sizes[j] != tensor->size(j) && j != axis) { Napi::TypeError::New(env, "Tensors have different sizes") .ThrowAsJavaScriptException(); return env.Undefined(); @@ -402,25 +403,25 @@ Napi::Value Tensor::Concat(const Napi::CallbackInfo &info) { for (size_t i = 0; i < rank; i++) { n_elem *= sizes[i]; } - ssize_t elem_size = tensors[0]->tensor_->element_size(); - void *newData = malloc(n_elem * elem_size); + ssize_t elem_size = tensors[0]->element_size(); + char *newData = new char[n_elem * elem_size]; size_t trip_step = 1; for (size_t j = 0; j < axis; j++) { - trip_step *= tensors[0]->tensor_->size(j); + trip_step *= tensors[0]->size(j); } size_t chunk_size = elem_size; for (size_t k = axis; k < rank; k++) { - chunk_size *= tensors[0]->tensor_->size(k); + chunk_size *= tensors[0]->size(k); } for (size_t i = 0; i < trip_step; i++) { for (size_t j = 0; j < n_tensors; j++) { - const void *data = tensors[j]->tensor_->const_data_ptr(); - memcpy(reinterpret_cast(newData) + j * chunk_size + + const char *data = (char*) tensors[j]->const_data_ptr(); + memcpy(newData + j * chunk_size + i * n_tensors * chunk_size, - reinterpret_cast(data) + chunk_size * i, + data + chunk_size * i, chunk_size); } } @@ -462,8 +463,11 @@ Napi::Value Tensor::Reshape(const Napi::CallbackInfo &info) { return env.Undefined(); } + char *data = new char[tensor_->nbytes()]; + memcpy(data, tensor_->const_data_ptr(), tensor_->nbytes()); + exec_aten::Tensor tensor(new exec_aten::TensorImpl( - tensor_->scalar_type(), rank, dims, tensor_->mutable_data_ptr())); + tensor_->scalar_type(), rank, dims, data)); return Tensor::New(tensor); } diff --git a/src/Tensor.h b/src/Tensor.h index 19f5663..0b2d7e9 100644 --- a/src/Tensor.h +++ b/src/Tensor.h @@ -26,6 +26,7 @@ class Tensor : public Napi::ObjectWrap { } inline exec_aten::Tensor GetTensor() { return *tensor_; } + inline exec_aten::Tensor* GetTensorPtr() { return tensor_.get(); } protected: Napi::Value Shape(const Napi::CallbackInfo &info);