diff --git a/scripts/infer.py b/scripts/infer.py index 66da415..509c4c5 100644 --- a/scripts/infer.py +++ b/scripts/infer.py @@ -138,8 +138,8 @@ def __init__(self, config_file="./configs/infer.yaml", **override): bundle_root = parser.get_parsed_content("bundle_root") ts = os.path.getmtime(config_file) self.model.image_encoder.encoder = TRTWrapper( - f"{bundle_root}/image_encoder", self.model.image_encoder.encoder, + f"{bundle_root}/image_encoder", precision="fp16", build_args={ "builder_optimization_level": 5, @@ -148,8 +148,8 @@ def __init__(self, config_file="./configs/infer.yaml", **override): timestamp=ts, ) self.model.class_head = TRTWrapper( - f"{bundle_root}/class_head", self.model.class_head, + f"{bundle_root}/class_head", precision="fp16", build_args={ "builder_optimization_level": 5,