From 9329721eae793279a6fffdfc61a9e45a24156c5f Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:36:41 +0100 Subject: [PATCH 1/6] Fix saving of models Mashing the pickle and the model into the same binary file doesn't work anymore. Probably due to upstream changes. Nowi, two separate files are created (".pkl" and ".model"). --- typhon/retrieval/qrnn/qrnn.py | 16 ++++++++++------ typhon/tests/retrieval/qrnn/test_qrnn.py | 7 ++++--- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/typhon/retrieval/qrnn/qrnn.py b/typhon/retrieval/qrnn/qrnn.py index 5be1cb08..0208104c 100644 --- a/typhon/retrieval/qrnn/qrnn.py +++ b/typhon/retrieval/qrnn/qrnn.py @@ -600,11 +600,14 @@ def load(path): The loaded QRNN object. """ - with open(path, 'rb') as f: + with open(path + ".pkl", 'rb') as f: qrnn = pickle.load(f) + + with open(path + ".model", 'rb') as f: backend = importlib.import_module(qrnn.backend) model = backend.load_model(f, qrnn.quantiles) qrnn.model = model + return qrnn def save(self, path): @@ -621,11 +624,12 @@ def save(self, path): store the model. """ - f = open(path, "wb") - pickle.dump(self, f) - backend = importlib.import_module(self.backend) - backend.save_model(f, self.model) - f.close() + with open(path + ".pkl", 'wb') as f: + pickle.dump(self, f) + + with open(path + ".model", 'wb') as f: + backend = importlib.import_module(self.backend) + backend.save_model(f, self.model) def __getstate__(self): diff --git a/typhon/tests/retrieval/qrnn/test_qrnn.py b/typhon/tests/retrieval/qrnn/test_qrnn.py index 327b55a1..e32404ef 100644 --- a/typhon/tests/retrieval/qrnn/test_qrnn.py +++ b/typhon/tests/retrieval/qrnn/test_qrnn.py @@ -87,9 +87,10 @@ def test_save_qrnn(self, backend): """ set_backend(backend) qrnn = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10)) - f = tempfile.NamedTemporaryFile() - qrnn.save(f.name) - qrnn_loaded = QRNN.load(f.name) + with tempfile.TemporaryDirectory() as d: + f = os.path.join(d, "qrnn") + qrnn.save(f) + qrnn_loaded = QRNN.load(f) x_pred = qrnn.predict(self.x_train) x_pred_loaded = qrnn.predict(self.x_train) From 2beb0f3702b065f5646fedca4ec4dc2e6d7da3f2 Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:37:12 +0100 Subject: [PATCH 2/6] Check for pytorch first, then keras keras >=3 is not supported, better to use pytorch by default instead. --- typhon/retrieval/qrnn/qrnn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/typhon/retrieval/qrnn/qrnn.py b/typhon/retrieval/qrnn/qrnn.py index 0208104c..0e5ecaa0 100644 --- a/typhon/retrieval/qrnn/qrnn.py +++ b/typhon/retrieval/qrnn/qrnn.py @@ -23,12 +23,12 @@ ################################################################################ try: - import typhon.retrieval.qrnn.models.keras as keras - backend = keras + import typhon.retrieval.qrnn.models.pytorch as pytorch + backend = pytorch except Exception as e: try: - import typhon.retrieval.qrnn.models.pytorch as pytorch - backend = pytorch + import typhon.retrieval.qrnn.models.keras as keras + backend = keras except: raise Exception("Couldn't import neither Keras nor Pytorch " "one of them must be available to use the QRNN" From 0fd88b030a443274c378632c57fd8a2a2b028ea9 Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:22:17 +0100 Subject: [PATCH 3/6] Add version check for keras, we only support 2 --- typhon/retrieval/qrnn/backends/keras.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/typhon/retrieval/qrnn/backends/keras.py b/typhon/retrieval/qrnn/backends/keras.py index 4a7830fc..f8d83188 100644 --- a/typhon/retrieval/qrnn/backends/keras.py +++ b/typhon/retrieval/qrnn/backends/keras.py @@ -18,6 +18,8 @@ from keras.models import Sequential, clone_model, Model from keras.layers import Dense, Activation, Dropout from keras.optimizers import SGD + if int(keras.__version__.split('.')[0]) != 2: + raise ImportError() except ImportError: raise ImportError( "Could not import the required Keras modules. The QRNN " From 158f586242c48bc093adb81d4981d94ea6d86aec Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:23:04 +0100 Subject: [PATCH 4/6] Add try catch with version check in model file as well --- typhon/retrieval/qrnn/models/keras.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/typhon/retrieval/qrnn/models/keras.py b/typhon/retrieval/qrnn/models/keras.py index e049809a..0131bb8d 100644 --- a/typhon/retrieval/qrnn/models/keras.py +++ b/typhon/retrieval/qrnn/models/keras.py @@ -7,11 +7,18 @@ """ import logging import numpy as np -import keras -from keras.models import Sequential -from keras.layers import Dense, Activation, deserialize -from keras.optimizers import SGD -import keras.backend as K +try: + import keras + from keras.models import Sequential + from keras.layers import Dense, Activation, deserialize + from keras.optimizers import SGD + import keras.backend as K + if int(keras.__version__.split('.')[0]) != 2: + raise ImportError() +except ImportError: + raise ImportError( + "Could not import the required Keras modules. The QRNN " + "implementation was developed for use with Keras version 2.0.9.") def save_model(f, model): From b6be86f116a9a5d2761e489b78d1e70e5b66b2ec Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:23:58 +0100 Subject: [PATCH 5/6] Comment out keras backend in tests Version 3 doesn't work and version 2 is not installable in Python >=3.12. --- typhon/tests/retrieval/qrnn/test_qrnn.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/typhon/tests/retrieval/qrnn/test_qrnn.py b/typhon/tests/retrieval/qrnn/test_qrnn.py index e32404ef..a3822020 100644 --- a/typhon/tests/retrieval/qrnn/test_qrnn.py +++ b/typhon/tests/retrieval/qrnn/test_qrnn.py @@ -14,12 +14,12 @@ # backends = [] -try: - import typhon.retrieval.qrnn.models.keras - - backends += ["keras"] -except: - pass +# try: +# import typhon.retrieval.qrnn.models.keras +# +# backends += ["keras"] +# except: +# pass try: import typhon.retrieval.qrnn.models.pytorch From 856c579ec9bc2408c5990df761e2dd5e1dd1d1aa Mon Sep 17 00:00:00 2001 From: Oliver Lemke Date: Thu, 14 Nov 2024 13:25:07 +0100 Subject: [PATCH 6/6] Fix torch.Tensor copy construction warning --- typhon/retrieval/qrnn/models/pytorch/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/typhon/retrieval/qrnn/models/pytorch/common.py b/typhon/retrieval/qrnn/models/pytorch/common.py index 17300fa0..ae8b1ec6 100644 --- a/typhon/retrieval/qrnn/models/pytorch/common.py +++ b/typhon/retrieval/qrnn/models/pytorch/common.py @@ -51,7 +51,7 @@ def load_model(f, quantiles): Returns: The loaded pytorch model. """ - model = torch.load(f) + model = torch.load(f, weights_only=False) return model @@ -92,8 +92,8 @@ class BatchedDataset(Dataset): def __init__(self, training_data, batch_size): x, y = training_data - self.x = torch.tensor(x, dtype=torch.float) - self.y = torch.tensor(y, dtype=torch.float) + self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float) + self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float) self.batch_size = batch_size def __len__(self):