Skip to content

Commit

Permalink
all inner TensorRT calls now go through runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
gerwin3 committed Aug 8, 2023
1 parent b91a1b3 commit 5f872f1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 31 deletions.
25 changes: 13 additions & 12 deletions crates/async-tensorrt/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl Builder {
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder.html#a68a8b59fbf86e42762b7087e6ffe6fb4)
#[inline(always)]
pub fn add_optimization_profile(&mut self) -> Result<()> {
self.inner.add_optimization_profile()
pub async fn add_optimization_profile(&mut self) -> Result<()> {
Future::new(|| self.inner.add_optimization_profile()).await
}

/// Create a new optimization profile.
Expand All @@ -54,8 +54,9 @@ impl Builder {
/// may or may not actually affect the building process later.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder.html#a68a8b59fbf86e42762b7087e6ffe6fb4)
pub fn with_optimization_profile(mut self) -> Result<Self> {
self.add_optimization_profile()?;
#[inline(always)]
pub async fn with_optimization_profile(mut self) -> Result<Self> {
self.add_optimization_profile().await?;
Ok(self)
}

Expand All @@ -67,8 +68,8 @@ impl Builder {
///
/// A [`BuilderConfig`] that can later be passed to `build_serialized_network`.
#[inline(always)]
pub fn config(&mut self) -> BuilderConfig {
self.inner.config()
pub async fn config(&mut self) -> BuilderConfig {
Future::new(|| self.inner.config()).await
}

/// Create a network definition object.
Expand All @@ -79,11 +80,11 @@ impl Builder {
///
/// * `flags` - Flags for specifying network properties.
#[inline(always)]
pub fn network_definition(
pub async fn network_definition(
&mut self,
flags: NetworkDefinitionCreationFlags,
) -> NetworkDefinition {
self.inner.network_definition(flags)
Future::new(|| self.inner.network_definition(flags)).await
}

/// Builds and serializes a network for the provided [`crate::ffi::network::NetworkDefinition`]
Expand Down Expand Up @@ -111,15 +112,15 @@ impl Builder {
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder.html#ab09433c57e3ef02f7aad672ec4235ea4)
#[inline(always)]
pub fn platform_has_fast_int8(&self) -> bool {
self.inner.platform_has_fast_int8()
pub async fn platform_has_fast_int8(&self) -> bool {
Future::new(|| self.inner.platform_has_fast_int8()).await
}

/// Determine whether the platform has fast native FP16.
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_builder.html#a6e42dd3ecb449ba54ffb823685a7ac47)
#[inline(always)]
pub fn platform_has_fast_fp16(&self) -> bool {
self.inner.platform_has_fast_fp16()
pub async fn platform_has_fast_fp16(&self) -> bool {
Future::new(|| self.inner.platform_has_fast_fp16()).await
}
}
30 changes: 15 additions & 15 deletions crates/async-tensorrt/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ impl Engine {
///
/// [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_cuda_engine.html#af2018924cbea2fa84808040e60c58405)
#[inline(always)]
pub fn num_io_tensors(&self) -> usize {
self.inner.num_io_tensors()
pub async fn num_io_tensors(&self) -> usize {
Future::new(|| self.inner.num_io_tensors()).await
}

/// Retrieve the name of an IO tensor.
Expand All @@ -49,8 +49,8 @@ impl Engine {
///
/// * `io_tensor_index` - IO tensor index.
#[inline(always)]
pub fn io_tensor_name(&self, io_tensor_index: usize) -> String {
self.inner.io_tensor_name(io_tensor_index)
pub async fn io_tensor_name(&self, io_tensor_index: usize) -> String {
Future::new(|| self.inner.io_tensor_name(io_tensor_index)).await
}

/// Get the shape of a tensor.
Expand All @@ -61,8 +61,8 @@ impl Engine {
///
/// * `tensor_name` - Tensor name.
#[inline(always)]
pub fn tensor_shape(&self, tensor_name: &str) -> Vec<usize> {
self.inner.tensor_shape(tensor_name)
pub async fn tensor_shape(&self, tensor_name: &str) -> Vec<usize> {
Future::new(|| self.inner.tensor_shape(tensor_name)).await
}

/// Get the IO mode of a tensor.
Expand All @@ -73,8 +73,8 @@ impl Engine {
///
/// * `tensor_name` - Tensor name.
#[inline(always)]
pub fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
self.inner.tensor_io_mode(tensor_name)
pub async fn tensor_io_mode(&self, tensor_name: &str) -> TensorIoMode {
Future::new(|| self.inner.tensor_io_mode(tensor_name)).await
}
}

Expand Down Expand Up @@ -205,13 +205,13 @@ mod tests {
#[tokio::test]
async fn test_engine_tensor_info() {
let engine = simple_engine!();
assert_eq!(engine.num_io_tensors(), 2);
assert_eq!(engine.io_tensor_name(0), "X");
assert_eq!(engine.io_tensor_name(1), "Y");
assert_eq!(engine.tensor_io_mode("X"), TensorIoMode::Input);
assert_eq!(engine.tensor_io_mode("Y"), TensorIoMode::Output);
assert_eq!(engine.tensor_shape("X"), &[1, 2]);
assert_eq!(engine.tensor_shape("Y"), &[2, 3]);
assert_eq!(engine.num_io_tensors().await, 2);
assert_eq!(engine.io_tensor_name(0).await, "X");
assert_eq!(engine.io_tensor_name(1).await, "Y");
assert_eq!(engine.tensor_io_mode("X").await, TensorIoMode::Input);
assert_eq!(engine.tensor_io_mode("Y").await, TensorIoMode::Output);
assert_eq!(engine.tensor_shape("X").await, &[1, 2]);
assert_eq!(engine.tensor_shape("Y").await, &[2, 3]);
}

#[tokio::test]
Expand Down
4 changes: 3 additions & 1 deletion crates/async-tensorrt/src/ffi/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ mod tests {
async fn test_parser_parses_onnx_file() {
let simple_onnx_file = simple_onnx_file!();
let mut builder = Builder::new().await;
let network = builder.network_definition(NetworkDefinitionCreationFlags::ExplicitBatchSize);
let network = builder
.network_definition(NetworkDefinitionCreationFlags::ExplicitBatchSize)
.await;
assert!(
Parser::parse_network_definition_from_file(network, &simple_onnx_file.path()).is_ok()
);
Expand Down
8 changes: 5 additions & 3 deletions crates/async-tensorrt/src/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ macro_rules! simple_network {
let mut builder = $crate::Builder::new()
.await
.with_optimization_profile()
.await
.unwrap();
let network =
builder.network_definition($crate::NetworkDefinitionCreationFlags::ExplicitBatchSize);
let network = builder
.network_definition($crate::NetworkDefinitionCreationFlags::ExplicitBatchSize)
.await;
let network =
$crate::Parser::parse_network_definition_from_file(network, &simple_onnx_file.path())
.unwrap();
Expand All @@ -17,7 +19,7 @@ macro_rules! simple_network {
macro_rules! simple_network_plan {
() => {{
let (mut builder, mut network) = $crate::tests::utils::simple_network!();
let builder_config = builder.config();
let builder_config = builder.config().await;
builder
.build_serialized_network(&mut network, builder_config)
.await
Expand Down

0 comments on commit 5f872f1

Please sign in to comment.