feat: correct linking & try implement tensor
Showing 7 changed files with 183 additions and 84 deletions. Some generated files are not rendered by default.
@@ -20,3 +20,4 @@ cpp_macros = "0.5"
 
 [build-dependencies]
 cpp_build = "0.5"
+build-target = "0.4"
@@ -1,68 +1,82 @@
 extern crate cpp_build;
+extern crate build_target;
 use std::path::Path;
+use build_target::Os;
 
-#[cfg(target_os = "macos")]
-static SO_EXT: &str = "dylib";
-#[cfg(target_os = "linux")]
-static SO_EXT: &str = "so";
-#[cfg(target_os = "windows")]
-static SO_EXT: &str = "dll";
-
-static ET_LIBS: [&str; 16] = [
-    "extension_data_loader",
-    "mpsdelegate",
-    "qnn_executorch_backend",
-    "portable_ops_lib",
-    "extension_module",
-    "xnnpack_backend",
-    "XNNPACK",
-    "cpuinfo",
-    "pthreadpool",
-    "vulkan_backend",
-    "optimized_kernels",
-    "optimized_ops_lib",
-    "optimized_native_cpu_ops_lib",
-    "quantized_kernels",
-    "quantized_ops_lib",
-    "custom_ops",
-];
+fn link_lib(lib_path: &Path, lib: &str, whole_link: bool) -> Result<(), ()> {
+    let so_ext = match build_target::target_os().unwrap() {
+        Os::Linux => "so",
+        Os::MacOs => "dylib",
+        Os::Windows => "dll",
+        _ => panic!("Unsupported OS"),
+    };
+    let filename = match lib {
+        "extension_module" => format!("lib{}.{}", lib, so_ext),
+        "qnn_executorch_backend" => format!("lib{}.{}", lib, so_ext),
+        _ => format!("lib{}.a", lib),
+    };
+    if lib_path.join(&filename).exists() {
+        if filename.ends_with(so_ext) {
+            println!("cargo:rustc-link-lib=dylib={}", lib);
+        } else {
+            if whole_link {
+                println!("cargo:rustc-link-lib=static:+whole-archive={}", lib);
+            } else {
+                println!("cargo:rustc-link-lib=static={}", lib);
+            }
+        }
+        return Ok(());
+    } else {
+        eprintln!("{} not found", filename);
+    }
+    Err(())
+}
 
 fn main() {
-    let base_path_str = std::env::var("EXECUTORCH_INSTALL_PREFIX").unwrap_or_else(|_| "executorch/cmake-out".to_string());
-    let lib_path = Path::new(&base_path_str).join("lib");
+    println!("cargo:rerun-if-changed=src/sampler.rs");
+    println!("cargo:rerun-if-changed=src/tensor.rs");
 
-    let mut lib_paths: Vec<String> = Vec::new();
-    for lib in ET_LIBS {
-        let filename = match lib {
-            "extension_module" => format!("lib{}.{}", lib, SO_EXT),
-            "qnn_executorch_backend" => format!("lib{}.{}", lib, SO_EXT),
-            _ => format!("lib{}.a", lib),
-        };
-        if lib_path.join(&filename).exists() {
-            lib_paths.push(filename);
-        }
-    }
+    let base_path = std::env::var("EXECUTORCH_INSTALL_PREFIX").unwrap_or_else(|_| "executorch/cmake-out".to_string());
+    let lib_path = Path::new(&base_path).join("lib");
 
+    println!("cargo:rustc-link-search=native={}", lib_path.display());
 
+    assert!(link_lib(&lib_path, "executorch", false).is_ok());
+    assert!(link_lib(&lib_path, "extension_module", false).is_ok());
+    assert!(link_lib(&lib_path, "extension_data_loader", false).is_ok());
 
+    // Optimized Kernels
+    if link_lib(&lib_path, "optimized_native_cpu_ops_lib", true).is_ok() {
+        assert!(link_lib(&lib_path, "optimized_kernels", false).is_ok());
+        assert!(link_lib(&lib_path, "portable_kernels", false).is_ok());
+        assert!(link_lib(&lib_path, "cpublas", false).is_ok());
+        assert!(link_lib(&lib_path, "eigen_blas", false).is_ok());
+    } else {
+        assert!(link_lib(&lib_path, "portable_ops_lib", true).is_ok());
+        assert!(link_lib(&lib_path, "portable_kernels", false).is_ok());
+    }
 
-    assert!(!lib_paths.is_empty(), "No lib files found in executorch/cmake-out/lib");
+    // Quantized Kernels
+    if link_lib(&lib_path, "quantized_kernels", false).is_ok() {
+        assert!(link_lib(&lib_path, "quantized_ops_lib", false).is_ok());
+    }
 
-    let mut config = cpp_build::Config::new();
-    config.flag("-std=c++17");
+    // misc.
+    let _ = link_lib(&lib_path, "cpuinfo", false);
+    let _ = link_lib(&lib_path, "pthreadpool", false);
 
-    for lib in lib_paths {
-        config.flag(&format!("-Wl,--whole-archive -l{} -Wl,--no-whole-archive", lib.trim_start_matches("lib").trim_end_matches(".a")));
-        // if lib.ends_with(SO_EXT) {
-        //     // println!("cargo:rustc-link-lib=dylib={}", lib.trim_start_matches("lib").trim_end_matches(SO_EXT));
-        // } else {
-        //     config.flag(&format!("--whole-archive -l{} --no-whole-archive", lib));
-        //     // println!("cargo:rustc-link-lib=static={}", lib.trim_start_matches("lib").trim_end_matches(".a"));
-        // }
+    // XNNPACK
+    if link_lib(&lib_path, "xnnpack_backend", true).is_ok() {
+        assert!(link_lib(&lib_path, "XNNPACK", false).is_ok());
     }
-    // println!("cargo:rustc-link-search=native={}", lib_path.display());
 
-    config.flag(&format!("-L{}", lib_path.display()));
 
-    config.build("src/lib.rs");
+    // Vulkan
+    let _ = link_lib(&lib_path, "vulkan_backend", true);
 
-    // tip rebuild if the library changes
-    println!("cargo:rerun-if-changed=src/sampler.rs");
+    // QNN
+    let _ = link_lib(&lib_path, "qnn_executorch_backend", true);
 
+    cpp_build::Config::new()
+        .flag("-std=c++17")
+        .build("src/lib.rs");
 }
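
Not part of the commit, just a sketch: the deleted ET_LIBS list also named backends that the new main() never links, for example mpsdelegate and custom_ops. If those archives exist under the install prefix, the same helper could cover them inside main(); whether they need whole-archive linking is an assumption here.

// Hypothetical additions (not in the commit): link leftover backends from the
// old ET_LIBS list when they are present in the install prefix.
let _ = link_lib(&lib_path, "mpsdelegate", true);
let _ = link_lib(&lib_path, "custom_ops", true);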
@@ -0,0 +1,10 @@
+#[macro_export]
+macro_rules! get_arg_as_vec {
+    ($cx:ident, $index:expr, $js_type:ty, $type:ty) => {
+        $cx.argument::<JsArray>($index)?
+            .to_vec(&mut $cx)?
+            .iter()
+            .map(|value| value.downcast::<$js_type, _>(&mut $cx).unwrap().value(&mut $cx) as $type)
+            .collect()
+    };
+}
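
A rough usage sketch, not part of the commit: the expansion ends in `?`, so the macro has to be called inside a Neon function that returns a JsResult, with the macro in scope. The function name and argument layout below are hypothetical.

use neon::prelude::*;

// Hypothetical Neon binding: argument 0 is read as a JS array of numbers into
// a Vec<i32> shape, argument 1 into a Vec<f32> data buffer.
fn make_tensor(mut cx: FunctionContext) -> JsResult<JsUndefined> {
    let shape: Vec<i32> = get_arg_as_vec!(cx, 0, JsNumber, i32);
    let data: Vec<f32> = get_arg_as_vec!(cx, 1, JsNumber, f32);
    let _ = (shape, data); // a real binding would build a tensor from these
    Ok(cx.undefined())
}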
@@ -1,33 +1,90 @@
 use neon::prelude::*;
 use neon::types::Finalize;
 use cpp::{cpp, cpp_class};
 
 cpp! {{
-#include <vector>
-#include <executorch/examples/models/llama2/sampler/sampler.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
 }}
 
-cpp_class!(pub unsafe struct Sampler as "torch::executor::Sampler");
-
-impl Sampler {
-    pub fn new(vocab_size: i32, temperature: f32, topp: f32, rng_seed: u64) -> Self {
-        unsafe {
-            cpp!([vocab_size as "int", temperature as "float", topp as "float", rng_seed as "uint64_t"] -> Sampler as "torch::executor::Sampler" {
-                return torch::executor::Sampler(vocab_size, temperature, topp, rng_seed);
-            })
-        }
-    }
-
-    pub fn sample(&self, param: Vec<f32>) -> i32 {
-        unsafe {
-            cpp!([self as "torch::executor::Sampler*", param as "std::vector<float>"] -> i32 as "int32_t" {
-                auto data = new float[param.size()];
-                memcpy(data, param.data(), param.size() * sizeof(float));
-                auto result = self->sample(data);
-                delete[] data;
-                return result;
-            })
-        }
-    }
-}
-
-impl Finalize for Sampler {}
+pub enum TensorType {
+    UInt8 = 0,
+    Int8 = 1,
+    Int16 = 2,
+    Int32 = 3,
+    Int64 = 4,
+    // Float16 = 5,
+    Float32 = 6,
+    Float64 = 7,
+    Bool = 11,
+}
+
+impl From<i32> for TensorType {
+    fn from(value: i32) -> Self {
+        match value {
+            0 => TensorType::UInt8,
+            1 => TensorType::Int8,
+            2 => TensorType::Int16,
+            3 => TensorType::Int32,
+            4 => TensorType::Int64,
+            6 => TensorType::Float32,
+            7 => TensorType::Float64,
+            11 => TensorType::Bool,
+            _ => panic!("Invalid dtype"),
+        }
+    }
+}
+
+cpp_class!(unsafe struct AtenTensor as "exec_aten::Tensor");
+
+impl AtenTensor {
+    fn new<T>(dtype: TensorType, dim: i64, shape: *mut i32, data: *mut T) -> Self {
+        let dtype_num = dtype as i32;
+        unsafe {
+            cpp!([dtype_num as "int32_t", dim as "ssize_t", shape as "int32_t*", data as "void*"] -> AtenTensor as "exec_aten::Tensor" {
+                auto tensor_impl = new exec_aten::TensorImpl(
+                    static_cast<exec_aten::ScalarType>(dtype_num),
+                    dim,
+                    shape,
+                    data
+                );
+                return exec_aten::Tensor(tensor_impl);
+            })
+        }
+    }
+
+    fn dim(&self) -> i64 {
+        unsafe {
+            cpp!([self as "exec_aten::Tensor*"] -> i64 as "ssize_t" {
+                return self->dim();
+            })
+        }
+    }
+}
+
+pub struct Tensor<T> {
+    tensor: AtenTensor,
+    data: Vec<T>,
+    shape: Vec<i32>,
+}
+
+impl<T> Tensor<T> {
+    pub fn dim(&self) -> i64 {
+        self.tensor.dim()
+    }
+}
+
+impl<T> Finalize for Tensor<T> {}
+
+// u8
+impl Tensor<u8> {
+    pub fn new(mut shape: Vec<i32>, mut data: Vec<u8>) -> Self {
+        let dim = shape.len() as i64;
+        let shape_ptr = shape.as_mut_ptr();
+        let data_ptr = data.as_mut_ptr();
+        let tensor = AtenTensor::new(TensorType::UInt8, dim, shape_ptr, data_ptr);
+        Tensor {
+            tensor,
+            data,
+            shape,
+        }
+    }
+}
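
A minimal usage sketch of the new wrapper, not part of the commit. The design choice is visible in the struct itself: Tensor<T> keeps the shape and data Vecs alive alongside the AtenTensor, since the underlying TensorImpl only borrows the raw pointers passed to it.

// Usage sketch, assuming the types above are in scope.
fn tensor_smoke_test() {
    // A 2x3 u8 tensor; dim() forwards to exec_aten::Tensor::dim().
    let t = Tensor::<u8>::new(vec![2, 3], vec![0, 1, 2, 3, 4, 5]);
    assert_eq!(t.dim(), 2);
}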