Skip to content

Commit

Permalink
Adjusted for TRT wrapper refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Boris Fomitchev <[email protected]>
  • Loading branch information
borisfom committed Aug 18, 2024
1 parent 6053338 commit 2ec59eb
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions scripts/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from .train import CONFIG
from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point

TRTWrapper, TRT_AVAILABLE = optional_import(
"monai.networks.trt_wrapper", name="TRTWrapper"
trt_wrap, TRT_AVAILABLE = optional_import(
"monai.networks", name="trt_wrap"
)

rearrange, _ = optional_import("einops", name="rearrange")
Expand Down Expand Up @@ -137,25 +137,24 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
if self.trt and TRT_AVAILABLE:
bundle_root = parser.get_parsed_content("bundle_root")
ts = os.path.getmtime(config_file)
self.model.image_encoder.encoder = TRTWrapper(
self.model.image_encoder.encoder,
f"{bundle_root}/image_encoder",
precision="fp16",
build_args={
trt_args = {
"precision": "fp16",
"build_args": {
"builder_optimization_level": 5,
"precision_constraints": "obey",
},
timestamp=ts,
"timestamp": ts
}

trt_wrap(
self.model.image_encoder.encoder,
f"{bundle_root}/image_encoder",
args=trt_args,
)
self.model.class_head = TRTWrapper(
trt_wrap(
self.model.class_head,
f"{bundle_root}/class_head",
precision="fp16",
build_args={
"builder_optimization_level": 5,
"precision_constraints": "obey",
},
timestamp=ts,
args=trt_args,
)
return

Expand Down

0 comments on commit 2ec59eb

Please sign in to comment.