Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue due to newer pytorch/keras versions #429

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions typhon/retrieval/qrnn/backends/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from keras.models import Sequential, clone_model, Model
from keras.layers import Dense, Activation, Dropout
from keras.optimizers import SGD
if int(keras.__version__.split('.')[0]) != 2:
raise ImportError()
except ImportError:
raise ImportError(
"Could not import the required Keras modules. The QRNN "
Expand Down
17 changes: 12 additions & 5 deletions typhon/retrieval/qrnn/models/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
"""
import logging
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, deserialize
from keras.optimizers import SGD
import keras.backend as K
try:
import keras
from keras.models import Sequential
from keras.layers import Dense, Activation, deserialize
from keras.optimizers import SGD
import keras.backend as K
if int(keras.__version__.split('.')[0]) != 2:
raise ImportError()
except ImportError:
raise ImportError(
"Could not import the required Keras modules. The QRNN "
"implementation was developed for use with Keras version 2.0.9.")


def save_model(f, model):
Expand Down
6 changes: 3 additions & 3 deletions typhon/retrieval/qrnn/models/pytorch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def load_model(f, quantiles):
Returns:
The loaded pytorch model.
"""
model = torch.load(f)
model = torch.load(f, weights_only=False)
return model


Expand Down Expand Up @@ -92,8 +92,8 @@ class BatchedDataset(Dataset):

def __init__(self, training_data, batch_size):
x, y = training_data
self.x = torch.tensor(x, dtype=torch.float)
self.y = torch.tensor(y, dtype=torch.float)
self.x = x if isinstance(x, torch.Tensor) else torch.tensor(x, dtype=torch.float)
self.y = y if isinstance(y, torch.Tensor) else torch.tensor(y, dtype=torch.float)
self.batch_size = batch_size

def __len__(self):
Expand Down
24 changes: 14 additions & 10 deletions typhon/retrieval/qrnn/qrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
################################################################################

try:
import typhon.retrieval.qrnn.models.keras as keras
backend = keras
import typhon.retrieval.qrnn.models.pytorch as pytorch
backend = pytorch
except Exception as e:
try:
import typhon.retrieval.qrnn.models.pytorch as pytorch
backend = pytorch
import typhon.retrieval.qrnn.models.keras as keras
backend = keras
except:
raise Exception("Couldn't import neither Keras nor Pytorch "
"one of them must be available to use the QRNN"
Expand Down Expand Up @@ -600,11 +600,14 @@ def load(path):

The loaded QRNN object.
"""
with open(path, 'rb') as f:
with open(path + ".pkl", 'rb') as f:
qrnn = pickle.load(f)

with open(path + ".model", 'rb') as f:
backend = importlib.import_module(qrnn.backend)
model = backend.load_model(f, qrnn.quantiles)
qrnn.model = model

return qrnn

def save(self, path):
Expand All @@ -621,11 +624,12 @@ def save(self, path):
store the model.

"""
f = open(path, "wb")
pickle.dump(self, f)
backend = importlib.import_module(self.backend)
backend.save_model(f, self.model)
f.close()
with open(path + ".pkl", 'wb') as f:
pickle.dump(self, f)

with open(path + ".model", 'wb') as f:
backend = importlib.import_module(self.backend)
backend.save_model(f, self.model)


def __getstate__(self):
Expand Down
19 changes: 10 additions & 9 deletions typhon/tests/retrieval/qrnn/test_qrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#

backends = []
try:
import typhon.retrieval.qrnn.models.keras

backends += ["keras"]
except:
pass
# try:
# import typhon.retrieval.qrnn.models.keras
#
# backends += ["keras"]
# except:
# pass

try:
import typhon.retrieval.qrnn.models.pytorch
Expand Down Expand Up @@ -87,9 +87,10 @@ def test_save_qrnn(self, backend):
"""
set_backend(backend)
qrnn = QRNN(self.x_train.shape[1], np.linspace(0.05, 0.95, 10))
f = tempfile.NamedTemporaryFile()
qrnn.save(f.name)
qrnn_loaded = QRNN.load(f.name)
with tempfile.TemporaryDirectory() as d:
f = os.path.join(d, "qrnn")
qrnn.save(f)
qrnn_loaded = QRNN.load(f)

x_pred = qrnn.predict(self.x_train)
x_pred_loaded = qrnn.predict(self.x_train)
Expand Down
Loading