From 856c579ec9bc2408c5990df761e2dd5e1dd1d1aa Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:25:07 +0100 Subject: [PATCH] Fix torch.Tensor copy construction warning --- typhon/retrieval/qrnn/models/pytorch/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/typhon/retrieval/qrnn/models/pytorch/common.py b/typhon/retrieval/qrnn/models/pytorch/common.py index 17300fa0..ae8b1ec6 100644 --- a/typhon/retrieval/qrnn/models/pytorch/common.py +++ b/typhon/retrieval/qrnn/models/pytorch/common.py @@ -51,7 +51,7 @@ def load_model(f, quantiles): Returns: The loaded pytorch model. """ - model = torch.load(f) + model = torch.load(f, weights_only=False) return model @@ -92,8 +92,8 @@ class BatchedDataset(Dataset): def __init__(self, training_data, batch_size): x, y = training_data - self.x = torch.tensor(x, dtype=torch.float) - self.y = torch.tensor(y, dtype=torch.float) + self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float) + self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float) self.batch_size = batch_size def __len__(self):