Skip to content

Commit

Permalink
Merge pull request #21 from mstallmo/tensorrt-version
Browse files Browse the repository at this point in the history
Expand build script to be more flexible with respect to TensorRT version
  • Loading branch information
mstallmo authored Sep 23, 2020
2 parents 6013dbd + fc26b68 commit 2565b27
Show file tree
Hide file tree
Showing 11 changed files with 157 additions and 117 deletions.
3 changes: 3 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
tensorrt-sys/3rdParty/TensorRT-5.1.5/libnvinfer_plugin_static.a filter=lfs diff=lfs merge=lfs -text
tensorrt-sys/3rdParty/TensorRT-5.1.5/libnvinfer_static.a filter=lfs diff=lfs merge=lfs -text
tensorrt-sys/3rdParty/TensorRT-5.1.5/libnvparsers_static.a filter=lfs diff=lfs merge=lfs -text
Git LFS file not shown
3 changes: 3 additions & 0 deletions tensorrt-sys/3rdParty/TensorRT-5.1.5/libnvinfer_static.a
Git LFS file not shown
3 changes: 3 additions & 0 deletions tensorrt-sys/3rdParty/TensorRT-5.1.5/libnvparsers_static.a
Git LFS file not shown
13 changes: 13 additions & 0 deletions tensorrt-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ build = "build.rs"
repository = "https://github.com/mstallmo/tensorrt-rs"
description = "Low level wrapper around Nvidia's TensorRT library"

# Feature flags pick which TensorRT / CUDA versions build.rs links against.
# Enable exactly one trt-* feature; trt-515 pulls in its matching CUDA feature
# automatically, while trt-713 must be paired with one of the cuda-* features
# below (build.rs selects the link paths from whichever is active).
[features]
default = ["trt-515"]

# TensorRT 5.1.5 (implies CUDA 10.1).
trt-515 = ["cuda-101"]

# TensorRT 7.1.3 — pair with a cuda-* feature.
trt-713 = []

cuda-101 = []

cuda-102 = []

cuda-110 = []

[dependencies]
libc = "0.2.62"

Expand Down
12 changes: 9 additions & 3 deletions tensorrt-sys/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@ CMake > 3.10

TensorRT-sys' bindings depend on TensorRT 5.1.5 to work correctly. While other versions of
TensorRT *may* work with the bindings there are no guarantees, as functions that are bound to may have been deprecated,
removed, or changed in future verions of TensorRT.
removed, or changed in future versions of TensorRT.

The prerequisites enumerated above are expected to be installed in their default location on Linux
(/usr/lib/x86_64-linux-gnu/)
The prerequisites enumerated above are expected to be installed in their default location on Linux. See the [nvidia
documentation](https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html#installing) around TensorRT for
further install information.

__Note:__ The tarball installation method described in the TensorRT documentation is likely to cause major headaches with
getting everything to link correctly. It is highly recommended to use the package manager method if possible.

Windows is not currently supported, but support should be coming soon!

### Support Matrix for TensorRT Classes
Anything not listed below currently does not have any support.
Expand Down
116 changes: 109 additions & 7 deletions tensorrt-sys/build.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,115 @@
use cmake::Config;
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::string::String;

fn main() {
let dst = Config::new("trt-sys").build();
/// Look up the on-disk location of the unversioned `<library_name>.so`
/// symlink using the dynamic-linker cache.
///
/// `ldconfig -p` lines look like:
///   `libnvinfer.so (libc6,x86-64) => /usr/lib/x86_64-linux-gnu/libnvinfer.so`
/// The first line ending in `<library_name>.so` wins, and the path on the
/// right-hand side of `=>` is returned. `None` means the library is not in
/// the linker cache (or the cache could not be read).
fn get_shared_lib_link_path(library_name: &str) -> Option<PathBuf> {
    let all_link_paths = get_all_possible_link_paths(library_name)?;
    let soname = format!("{}.so", library_name);
    for line in all_link_paths.lines() {
        if line.ends_with(&soname) {
            // Everything after the last "=>" is the filesystem path; if the
            // separator is absent, fall back to the whole line (matches the
            // behavior of taking the last element of a `split("=>")`).
            let link_path = line.rsplit("=>").next().unwrap_or(line);
            println!("link path: {}", link_path);
            return Some(PathBuf::from(link_path.trim()));
        }
    }
    None
}

/// Return every line of `ldconfig -p` output that mentions `library_name`,
/// newline-terminated, or `None` when the output is not valid UTF-8.
///
/// Filtering is done in-process instead of piping through an external
/// `grep`, which removes one tool dependency and the associated pipe
/// plumbing. (This also drops a stray statement that referenced an
/// undefined `dst` variable.)
///
/// # Panics
/// Panics if `ldconfig` cannot be spawned at all — without it we have no
/// way to locate the TensorRT libraries.
fn get_all_possible_link_paths(library_name: &str) -> Option<String> {
    let output = Command::new("ldconfig")
        .arg("-p")
        .output()
        .expect("Failed to run ldconfig");

    let stdout = String::from_utf8(output.stdout).ok()?;
    let matching_lines = stdout
        .lines()
        .filter(|line| line.contains(library_name))
        .map(|line| format!("{}\n", line))
        .collect::<String>();
    Some(matching_lines)
}

/// Emit the cargo link configuration for a TensorRT 5.1.5 installation.
///
/// `full_library_path` is the resolved target of the `libnvinfer.so`
/// symlink; its version suffix is checked against the `trt-515` feature this
/// crate was built with before anything is linked.
///
/// # Panics
/// Panics when the installed `libnvinfer` is not version 5.1.5.
#[cfg(feature = "trt-515")]
fn configuration(full_library_path: &Path) {
    let path_str = full_library_path.to_str().unwrap();
    if !path_str.ends_with("5.1.5") {
        panic!("Invalid nvinfer version found. Expected: libnvinfer.so.5.1.5, Found: {}", path_str);
    }

    // Build the C++ shim against the matching TensorRT headers.
    let dst = Config::new("trt-sys").define("TRT_VERSION", "5.1.5").build();
    println!("cargo:rustc-link-search=native={}", dst.display());
    println!("cargo:rustc-flags=-l dylib=nvinfer");
    println!("cargo:rustc-flags=-l dylib=nvparsers");
    // 5.1.5 additionally needs the plugin library, vendored as a static
    // archive under 3rdParty/ (tracked via git-lfs — see .gitattributes).
    println!("cargo:rustc-flags=-L {}/3rdParty/TensorRT-5.1.5", env!("CARGO_MANIFEST_DIR"));
    println!("cargo:rustc-flags=-l static=nvinfer_plugin_static");
    cuda_configuration();
}

/// Emit the cargo link configuration for a TensorRT 7.1.3 installation.
///
/// `full_library_path` is the resolved target of the `libnvinfer.so`
/// symlink; its version suffix is checked against the `trt-713` feature this
/// crate was built with before anything is linked. Unlike 5.1.5, no vendored
/// static plugin library is required.
///
/// # Panics
/// Panics when the installed `libnvinfer` is not version 7.1.3.
#[cfg(feature = "trt-713")]
fn configuration(full_library_path: &Path) {
    let path_str = full_library_path.to_str().unwrap();
    if !path_str.ends_with("7.1.3") {
        panic!("Invalid nvinfer version found. Expected: libnvinfer.so.7.1.3, Found: {}", path_str);
    }

    // Build the C++ shim against the matching TensorRT headers.
    let dst = Config::new("trt-sys").define("TRT_VERSION", "7.1.3").build();
    println!("cargo:rustc-link-search=native={}", dst.display());
    println!("cargo:rustc-flags=-l dylib=nvinfer");
    println!("cargo:rustc-flags=-l dylib=nvparsers");
    cuda_configuration();
}

/// Link the CUDA 10.1 runtime libraries from the toolkit's default
/// install prefix (/usr/local/cuda-10.1).
#[cfg(feature = "cuda-101")]
fn cuda_configuration() {
    println!("cargo:rustc-flags=-L /usr/local/cuda-10.1/lib64");
    for lib in &["cudart", "cublas", "cublasLt", "cudnn"] {
        println!("cargo:rustc-flags=-l dylib={}", lib);
    }
}

/// Link the CUDA 10.2 runtime from the toolkit's default install prefix
/// (/usr/local/cuda-10.2).
#[cfg(feature = "cuda-102")]
fn cuda_configuration() {
    for flag in &["-L /usr/local/cuda-10.2/lib64", "-l dylib=cudart"] {
        println!("cargo:rustc-flags={}", flag);
    }
}

/// Link the CUDA 11.0 runtime from the toolkit's default install prefix
/// (/usr/local/cuda-11.0).
#[cfg(feature = "cuda-110")]
fn cuda_configuration() {
    for flag in &["-L /usr/local/cuda-11.0/lib64", "-l dylib=cudart"] {
        println!("cargo:rustc-flags={}", flag);
    }
}

// Not sure if I love this solution but I think it's relatively robust enough for now on Unix systems.
// Still have to thoroughly test what happens with a TRT library installed that's not done by the
// dpkg. It's possible that we'll just have to fall back to only supporting one system library and assuming that
// the user has the correct library installed and is viewable via ldconfig.
//
// Hopefully something like this will work for Windows installs as well, not having a default library
// install location will make that significantly harder.
fn main() {
    // Link flags that apply regardless of the TensorRT/CUDA version in use.
    println!("cargo:rustc-link-lib=static=trt-sys");
    println!("cargo:rustc-flags=-l dylib=stdc++");
    println!("cargo:rustc-flags=-l dylib=nvinfer");
    println!("cargo:rustc-flags=-l dylib=nvparsers");
    println!("cargo:rustc-flags=-L /usr/local/cuda/lib64");
    println!("cargo:rustc-flags=-l dylib=cudart");

    // Same message for both failure modes (symlink missing from the linker
    // cache, or the symlink target unreadable).
    const INSTALL_HELP: &str = "libnvinfer.so not found! See https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-515/tensorrt-install-guide/index.html for install instructions";

    // Locate the unversioned libnvinfer.so symlink via ldconfig, then resolve
    // it to the fully-versioned library file so `configuration` can verify
    // that the installed version matches the enabled cargo feature.
    let link_path = get_shared_lib_link_path("libnvinfer")
        .unwrap_or_else(|| panic!("{}", INSTALL_HELP));
    let full_library_path = std::fs::read_link(link_path)
        .unwrap_or_else(|_| panic!("{}", INSTALL_HELP));
    configuration(&full_library_path);
}
98 changes: 7 additions & 91 deletions tensorrt-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod tests {
use std::io::prelude::*;
use std::os::raw::{c_char, c_int, c_void};

#[cfg(feature = "trt-515")]
#[test]
fn tensorrt_version() {
let mut c_buf = Vec::<c_char>::with_capacity(6);
Expand All @@ -28,97 +29,12 @@ mod tests {
assert_eq!("5.1.5", c_str.to_str().unwrap());
}

// #[test]
// fn cuda_runtime() {
// let logger = unsafe { create_logger() };
// let runtime = unsafe { create_infer_runtime(logger) };
//
// let mut f = File::open("resnet34-unet-Aug25-07-25-16-best.engine").unwrap();
// let mut buffer = Vec::new();
// f.read_to_end(&mut buffer).unwrap();
// let engine = unsafe {
// deserialize_cuda_engine(
// runtime,
// buffer.as_ptr() as *const c_void,
// buffer.len() as u64,
// )
// };
//
// let bindingNameCStr = unsafe {
// let bindingName = get_binding_name(engine, 0);
// CStr::from_ptr(bindingName)
// };
// println!(
// "Binding name for index {} is {}",
// 0,
// bindingNameCStr.to_str().unwrap()
// );
//
// let execution_context = unsafe { engine_create_execution_context(engine) };
// unsafe { context_set_name(execution_context, CString::new("Mason").unwrap().as_ptr()) };
// let context_name_cstr = unsafe {
// let context_name = context_get_name(execution_context);
// CStr::from_ptr(context_name)
// };
// println!("Context name is {}", context_name_cstr.to_str().unwrap());
//
// let input_binding =
// unsafe { get_binding_index(engine, CString::new("data").unwrap().as_ptr()) };
// println!("Binding index for data is {}", input_binding);
//
// unsafe {
// destroy_excecution_context(execution_context);
// destroy_cuda_engine(engine);
// destroy_infer_runtime(runtime);
// delete_logger(logger);
// }
// }
//
// #[test]
// fn host_memory() {
// let logger = unsafe { create_logger() };
// let runtime = unsafe { create_infer_runtime(logger) };
//
// let mut f = File::open("resnet34-unet-Aug25-07-25-16-best.engine").unwrap();
// let mut buffer = Vec::new();
// f.read_to_end(&mut buffer).unwrap();
// let engine = unsafe {
// deserialize_cuda_engine(
// runtime,
// buffer.as_ptr() as *const c_void,
// buffer.len() as u64,
// )
// };
//
// let host_memory = unsafe { engine_serialize(engine) };
// let memory_sise = unsafe { host_memory_get_size(host_memory) };
// println!("Host Memory Size of Engine: {}", memory_sise);
// }

#[cfg(feature = "trt-713")]
#[test]
fn uff_parser() {
let parser = unsafe { uffparser_create_uff_parser() };
let mut d_vec = vec![3, 256, 256];
let mut type_vec = vec![1, 0, 0];
let dims = unsafe {
crate::create_dims(
3,
d_vec.as_mut_ptr() as *mut c_int,
type_vec.as_mut_ptr() as *mut c_int,
)
};
let input_ret = unsafe {
uffparser_register_input(parser, CString::new("input").unwrap().as_ptr(), dims)
};
assert_eq!(input_ret, true);

let output_ret = unsafe {
uffparser_register_output(parser, CString::new("sigmoid/Sigmoid").unwrap().as_ptr())
};
assert_eq!(output_ret, true);

unsafe {
uffparser_destroy_uff_parser(parser);
}
fn tensorrt_version() {
let mut c_buf = Vec::<c_char>::with_capacity(6);
unsafe { get_tensorrt_version(c_buf.as_mut_ptr()) };
let c_str = unsafe { CStr::from_ptr(c_buf.as_ptr()) };
assert_eq!("7.1.3", c_str.to_str().unwrap());
}
}
18 changes: 5 additions & 13 deletions tensorrt-sys/trt-sys/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
# Build configuration for the C++ shim library ("trt-sys") that the Rust
# bindings link against.
cmake_minimum_required(VERSION 3.10)
project(LibTRT LANGUAGES CXX CUDA)

# TRT_VERSION is supplied by the cargo build script (build.rs) via -D.
# NOTE(review): only 5.1.5 forwards the version to the compiler as a
# preprocessor define here — confirm whether 7.1.3 builds need it too.
if(${TRT_VERSION} MATCHES "5.1.5")
message(STATUS "TRT version is ${TRT_VERSION}")
add_definitions(-DTRT_VERSION="${TRT_VERSION}")
endif()

set(CMAKE_CXX_STANDARD 17)

set(CMAKE_CXX_FLAGS "-O3 -Wall -Wextra -Werror -Wno-unknown-pragmas")

# All shim sources, one subdirectory per wrapped TensorRT class.
# NOTE(review): "TRTRuntime/*cpp" is missing the dot ("*.cpp") — it still
# matches *.cpp files but is broader than intended; worth tightening.
file(GLOB source_files
"TRTLogger/*.h"
"TRTLogger/*.cpp"
"TRTRuntime/*.h"
"TRTRuntime/*cpp"
"TRTCudaEngine/*.h"
"TRTCudaEngine/*.cpp"
"TRTContext/*.h"
"TRTContext/*.cpp"
"TRTUffParser/*.h"
"TRTUffParser/*.cpp"
"TRTDims/*.h"
"TRTDims/*.cpp"
"TRTBuilder/*.h"
"TRTBuilder/*.cpp"
"TRTNetworkDefinition/*.h"
"TRTNetworkDefinition/*.cpp"
"TRTHostMemory/*.h"
"TRTHostMemory/*.cpp"
"*.h"
)

# Locate the CUDA runtime next to the CUDA toolchain CMake already found.
find_library(CUDART_LIBRARY cudart ${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})

add_library(trt-sys STATIC ${source_files})
target_link_libraries(trt-sys PRIVATE nvinfer ${CUDART_LIBRARY})
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

# The cargo build script consumes the installed archive from this prefix.
install(TARGETS trt-sys DESTINATION .)
1 change: 0 additions & 1 deletion tensorrt-sys/trt-sys/TRTContext/TRTContext.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//
// Created by mason on 9/17/19.
//
#include <cstdlib>
#include <memory>
#include <cuda_runtime.h>
#include "NvInfer.h"
Expand Down
4 changes: 2 additions & 2 deletions tensorrt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ description = "Rust library for using Nvidia's TensorRT deep learning accelerati

[dependencies]
# Uncomment when working locally
#tensorrt-sys = {version = "0.2", path="../tensorrt-sys"}
tensorrt-sys = "0.2.1"
tensorrt-sys = {version = "0.2", path="../tensorrt-sys"}
#tensorrt-sys = "0.2.1"
ndarray = "0.13.1"
ndarray-image = "0.2.1"
image = "0.23.9"
Expand Down

0 comments on commit 2565b27

Please sign in to comment.