From 0c13f37bec2f048ce1af6d68dca14d28ecfdad5a Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Sun, 3 Sep 2023 17:16:00 -0400 Subject: [PATCH] Bugfix: load checkpoint also if "model_state_dict" is not a key --- gunpowder/torch/nodes/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gunpowder/torch/nodes/predict.py b/gunpowder/torch/nodes/predict.py index d426d2a8..3e5ba8f1 100644 --- a/gunpowder/torch/nodes/predict.py +++ b/gunpowder/torch/nodes/predict.py @@ -104,7 +104,7 @@ def start(self): if "model_state_dict" in checkpoint: self.model.load_state_dict(checkpoint["model_state_dict"]) else: - self.model.load_state_dict() + self.model.load_state_dict(checkpoint) def predict(self, batch, request): inputs = self.get_inputs(batch)