Commit

feat: correct linking & try implement tensor

hans00 committed Jun 4, 2024
1 parent 4d69d69 commit 6ab60f4
Showing 7 changed files with 183 additions and 84 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -20,3 +20,4 @@ cpp_macros = "0.5"
 
 [build-dependencies]
 cpp_build = "0.5"
+build-target = "0.4"
122 changes: 68 additions & 54 deletions build.rs
@@ -1,68 +1,82 @@
 extern crate cpp_build;
+extern crate build_target;
 use std::path::Path;
+use build_target::Os;
 
-#[cfg(target_os = "macos")]
-static SO_EXT: &str = "dylib";
-#[cfg(target_os = "linux")]
-static SO_EXT: &str = "so";
-#[cfg(target_os = "windows")]
-static SO_EXT: &str = "dll";
-
-static ET_LIBS: [&str; 16] = [
-    "extension_data_loader",
-    "mpsdelegate",
-    "qnn_executorch_backend",
-    "portable_ops_lib",
-    "extension_module",
-    "xnnpack_backend",
-    "XNNPACK",
-    "cpuinfo",
-    "pthreadpool",
-    "vulkan_backend",
-    "optimized_kernels",
-    "optimized_ops_lib",
-    "optimized_native_cpu_ops_lib",
-    "quantized_kernels",
-    "quantized_ops_lib",
-    "custom_ops",
-];
+fn link_lib(lib_path: &Path, lib: &str, whole_link: bool) -> Result<(), ()> {
+    let so_ext = match build_target::target_os().unwrap() {
+        Os::Linux => "so",
+        Os::MacOs => "dylib",
+        Os::Windows => "dll",
+        _ => panic!("Unsupported OS"),
+    };
+    let filename = match lib {
+        "extension_module" => format!("lib{}.{}", lib, so_ext),
+        "qnn_executorch_backend" => format!("lib{}.{}", lib, so_ext),
+        _ => format!("lib{}.a", lib),
+    };
+    if lib_path.join(&filename).exists() {
+        if filename.ends_with(so_ext) {
+            println!("cargo:rustc-link-lib=dylib={}", lib);
+        } else {
+            if whole_link {
+                println!("cargo:rustc-link-lib=static:+whole-archive={}", lib);
+            } else {
+                println!("cargo:rustc-link-lib=static={}", lib);
+            }
+        }
+        return Ok(());
+    } else {
+        eprintln!("{} not found", filename);
+    }
+    Err(())
+}
 
 fn main() {
-    let base_path_str = std::env::var("EXECUTORCH_INSTALL_PREFIX").unwrap_or_else(|_| "executorch/cmake-out".to_string());
-    let lib_path = Path::new(&base_path_str).join("lib");
+    println!("cargo:rerun-if-changed=src/sampler.rs");
+    println!("cargo:rerun-if-changed=src/tensor.rs");
 
-    let mut lib_paths: Vec<String> = Vec::new();
-    for lib in ET_LIBS {
-        let filename = match lib {
-            "extension_module" => format!("lib{}.{}", lib, SO_EXT),
-            "qnn_executorch_backend" => format!("lib{}.{}", lib, SO_EXT),
-            _ => format!("lib{}.a", lib),
-        };
-        if lib_path.join(&filename).exists() {
-            lib_paths.push(filename);
-        }
-    }
+    let base_path = std::env::var("EXECUTORCH_INSTALL_PREFIX").unwrap_or_else(|_| "executorch/cmake-out".to_string());
+    let lib_path = Path::new(&base_path).join("lib");
+
+    println!("cargo:rustc-link-search=native={}", lib_path.display());
+
+    assert!(link_lib(&lib_path, "executorch", false).is_ok());
+    assert!(link_lib(&lib_path, "extension_module", false).is_ok());
+    assert!(link_lib(&lib_path, "extension_data_loader", false).is_ok());
+
+    // Optimized Kernels
+    if link_lib(&lib_path, "optimized_native_cpu_ops_lib", true).is_ok() {
+        assert!(link_lib(&lib_path, "optimized_kernels", false).is_ok());
+        assert!(link_lib(&lib_path, "portable_kernels", false).is_ok());
+        assert!(link_lib(&lib_path, "cpublas", false).is_ok());
+        assert!(link_lib(&lib_path, "eigen_blas", false).is_ok());
+    } else {
+        assert!(link_lib(&lib_path, "portable_ops_lib", true).is_ok());
+        assert!(link_lib(&lib_path, "portable_kernels", false).is_ok());
+    }
 
-    assert!(!lib_paths.is_empty(), "No lib files found in executorch/cmake-out/lib");
+    // Quantized Kernels
+    if link_lib(&lib_path, "quantized_kernels", false).is_ok() {
+        assert!(link_lib(&lib_path, "quantized_ops_lib", false).is_ok());
+    }
 
-    let mut config = cpp_build::Config::new();
-    config.flag("-std=c++17");
+    // misc.
+    let _ = link_lib(&lib_path, "cpuinfo", false);
+    let _ = link_lib(&lib_path, "pthreadpool", false);
 
-    for lib in lib_paths {
-        config.flag(&format!("-Wl,--whole-archive -l{} -Wl,--no-whole-archive", lib.trim_start_matches("lib").trim_end_matches(".a")));
-        // if lib.ends_with(SO_EXT) {
-        //     // println!("cargo:rustc-link-lib=dylib={}", lib.trim_start_matches("lib").trim_end_matches(SO_EXT));
-        // } else {
-        //     config.flag(&format!("--whole-archive -l{} --no-whole-archive", lib));
-        //     // println!("cargo:rustc-link-lib=static={}", lib.trim_start_matches("lib").trim_end_matches(".a"));
-        // }
-    }
-    // println!("cargo:rustc-link-search=native={}", lib_path.display());
+    // XNNPACK
+    if link_lib(&lib_path, "xnnpack_backend", true).is_ok() {
+        assert!(link_lib(&lib_path, "XNNPACK", false).is_ok());
+    }
 
-    config.flag(&format!("-L{}", lib_path.display()));
+    // Vulkan
+    let _ = link_lib(&lib_path, "vulkan_backend", true);
 
-    config.build("src/lib.rs");
+    // QNN
+    let _ = link_lib(&lib_path, "qnn_executorch_backend", true);
 
-    // tip rebuild if the library changes
-    println!("cargo:rerun-if-changed=src/sampler.rs");
+    cpp_build::Config::new()
+        .flag("-std=c++17")
+        .build("src/lib.rs");
 }
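Note on the linking rework above: the old script chose the shared-library suffix with #[cfg(target_os = ...)], but cfg attributes in a build script are evaluated for the host that compiles the script, not for the compilation target, and whole-archive linking was smuggled in as raw -Wl,--whole-archive flags through cpp_build. The new link_lib helper instead queries the build-target crate while the script runs and emits ordinary Cargo link directives. A minimal stand-alone sketch of that directive pattern (illustrative only, not code from the commit; emit_static is a hypothetical name, library names reused from the diff):

    fn emit_static(lib: &str, whole_archive: bool) {
        // Cargo link directives are plain lines on the build script's stdout.
        if whole_archive {
            // +whole-archive keeps every object in the archive, which matters for
            // ops libraries whose kernels are reachable only via static registration.
            println!("cargo:rustc-link-lib=static:+whole-archive={}", lib);
        } else {
            println!("cargo:rustc-link-lib=static={}", lib);
        }
    }

    fn main() {
        println!("cargo:rustc-link-search=native=executorch/cmake-out/lib");
        emit_static("portable_ops_lib", true);   // registry: every symbol must survive
        emit_static("portable_kernels", false);  // pulled in by the registry as needed
    }

The static:+whole-archive link modifier requires Rust 1.61 or newer.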
15 changes: 13 additions & 2 deletions src/lib.rs
@@ -1,6 +1,16 @@
-use neon::prelude::*;
 pub mod sampler;
-use sampler::Sampler;
+pub mod tensor;
+pub mod macros;
 
+use neon::prelude::*;
+use sampler::*;
+use tensor::*;
+
+fn create_u8_tensor(mut cx: FunctionContext) -> JsResult<JsBox<Tensor::<u8>>> {
+    let shape = get_arg_as_vec!(cx, 0, JsNumber, i32);
+    let data = get_arg_as_vec!(cx, 1, JsNumber, u8);
+    Ok(cx.boxed(Tensor::<u8>::new(shape, data)))
+}
+
 fn create_sampler(mut cx: FunctionContext) -> JsResult<JsBox<Sampler>> {
     let vocab_size = cx.argument::<JsNumber>(0)?.value(&mut cx) as i32;
@@ -24,5 +34,6 @@ fn sampler_sample(mut cx: FunctionContext) -> JsResult<JsNumber> {
 fn main(mut cx: ModuleContext) -> NeonResult<()> {
     cx.export_function("createSampler", create_sampler)?;
     cx.export_function("samplerSample", sampler_sample)?;
+    cx.export_function("createU8Tensor", create_u8_tensor)?;
     Ok(())
 }
10 changes: 10 additions & 0 deletions src/macros.rs
@@ -0,0 +1,10 @@
+#[macro_export]
+macro_rules! get_arg_as_vec {
+    ($cx:ident, $index:expr, $js_type:ty, $type:ty) => {
+        $cx.argument::<JsArray>($index)?
+            .to_vec(&mut $cx)?
+            .iter()
+            .map(|value| value.downcast::<$js_type, _>(&mut $cx).unwrap().value(&mut $cx) as $type)
+            .collect()
+    };
+}
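This macro is what create_u8_tensor in src/lib.rs uses to pull typed vectors out of JavaScript arguments. Hand-expanded, get_arg_as_vec!(cx, 0, JsNumber, i32) is roughly the following (a sketch assuming neon 1.x; shape_argument is a hypothetical name, and the unwrap panics, and therefore throws, if an element is not a number):

    use neon::prelude::*;

    fn shape_argument(mut cx: FunctionContext) -> NeonResult<Vec<i32>> {
        // Argument 0 must be a JS array; each element is downcast to a JsNumber
        // and truncated to the requested integer type.
        let shape: Vec<i32> = cx
            .argument::<JsArray>(0)?
            .to_vec(&mut cx)?
            .iter()
            .map(|value| value.downcast::<JsNumber, _>(&mut cx).unwrap().value(&mut cx) as i32)
            .collect();
        Ok(shape)
    }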
3 changes: 1 addition & 2 deletions src/sampler.rs
@@ -1,8 +1,7 @@
 use neon::prelude::*;
-use neon::types::Finalize;
 use cpp::{cpp, cpp_class};
-
 cpp! {{
     #include <vector>
     #include <executorch/examples/models/llama2/sampler/sampler.h>
+    #include <executorch/examples/models/llama2/sampler/sampler.cpp>
 }}
109 changes: 83 additions & 26 deletions src/tensor.rs
@@ -1,33 +1,90 @@
 use neon::prelude::*;
 use neon::types::Finalize;
 use cpp::{cpp, cpp_class};
 
 cpp! {{
     #include <vector>
-    #include <executorch/examples/models/llama2/sampler/sampler.h>
+    #include <executorch/runtime/core/exec_aten/exec_aten.h>
 }}
 
-cpp_class!(pub unsafe struct Sampler as "torch::executor::Sampler");
-
-impl Sampler {
-    pub fn new(vocab_size: i32, temperature: f32, topp: f32, rng_seed: u64) -> Self {
-        unsafe {
-            cpp!([vocab_size as "int", temperature as "float", topp as "float", rng_seed as "uint64_t"] -> Sampler as "torch::executor::Sampler" {
-                return torch::executor::Sampler(vocab_size, temperature, topp, rng_seed);
-            })
-        }
-    }
-
-    pub fn sample(&self, param: Vec<f32>) -> i32 {
-        unsafe {
-            cpp!([self as "torch::executor::Sampler*", param as "std::vector<float>"] -> i32 as "int32_t" {
-                auto data = new float[param.size()];
-                memcpy(data, param.data(), param.size() * sizeof(float));
-                auto result = self->sample(data);
-                delete[] data;
-                return result;
-            })
-        }
-    }
-}
+pub enum TensorType {
+    UInt8 = 0,
+    Int8 = 1,
+    Int16 = 2,
+    Int32 = 3,
+    Int64 = 4,
+    // Float16 = 5,
+    Float32 = 6,
+    Float64 = 7,
+    Bool = 11,
+}
 
-impl Finalize for Sampler {}
+impl From<i32> for TensorType {
+    fn from(value: i32) -> Self {
+        match value {
+            0 => TensorType::UInt8,
+            1 => TensorType::Int8,
+            2 => TensorType::Int16,
+            3 => TensorType::Int32,
+            4 => TensorType::Int64,
+            6 => TensorType::Float32,
+            7 => TensorType::Float64,
+            11 => TensorType::Bool,
+            _ => panic!("Invalid dtype"),
+        }
+    }
+}
+
+cpp_class!(unsafe struct AtenTensor as "exec_aten::Tensor");
+
+impl AtenTensor {
+    fn new<T>(dtype: TensorType, dim: i64, shape: *mut i32, data: *mut T) -> Self {
+        let dtype_num = dtype as i32;
+        unsafe {
+            cpp!([dtype_num as "int32_t", dim as "ssize_t", shape as "int32_t*", data as "void*"] -> AtenTensor as "exec_aten::Tensor" {
+                auto tensor_impl = new exec_aten::TensorImpl(
+                    static_cast<exec_aten::ScalarType>(dtype_num),
+                    dim,
+                    shape,
+                    data
+                );
+                return exec_aten::Tensor(tensor_impl);
+            })
+        }
+    }
+
+    fn dim(&self) -> i64 {
+        unsafe {
+            cpp!([self as "exec_aten::Tensor*"] -> i64 as "ssize_t" {
+                return self->dim();
+            })
+        }
+    }
+}
+
+pub struct Tensor<T> {
+    tensor: AtenTensor,
+    data: Vec<T>,
+    shape: Vec<i32>,
+}
+
+impl<T> Tensor<T> {
+    pub fn dim(&self) -> i64 {
+        self.tensor.dim()
+    }
+}
+
+impl<T> Finalize for Tensor<T> {}
+
+// u8
+impl Tensor<u8> {
+    pub fn new(mut shape: Vec<i32>, mut data: Vec<u8>) -> Self {
+        let dim = shape.len() as i64;
+        let shape_ptr = shape.as_mut_ptr();
+        let data_ptr = data.as_mut_ptr();
+        let tensor = AtenTensor::new(TensorType::UInt8, dim, shape_ptr, data_ptr);
+        Tensor {
+            tensor,
+            data,
+            shape,
+        }
+    }
+}
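A note on the ownership pattern in Tensor<T> above: exec_aten::TensorImpl stores raw pointers to the shape and data buffers rather than copying them, so the wrapper deliberately moves the two Vecs in alongside the AtenTensor. They look unused on the Rust side, but they keep the heap buffers alive (at stable addresses) for as long as the boxed tensor is reachable from JavaScript. Hypothetical Rust-side usage (not part of the commit; tensor_demo is an invented name, and this assumes the crate links against ExecuTorch as configured in build.rs):

    fn tensor_demo() {
        // A 2x3 u8 tensor: six elements, row-major. Shape and data are moved
        // into the wrapper, which hands raw pointers to the C++ TensorImpl.
        let t = Tensor::<u8>::new(vec![2, 3], vec![10, 20, 30, 40, 50, 60]);
        assert_eq!(t.dim(), 2);
    }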
