Model factory #351

Open
wants to merge 31 commits into main
Changes from 27 commits
31 commits
4ce9d7b
corrected bug in assertion: missing self.env
Jan 21, 2022
e230152
blacked main
riccardodv Feb 16, 2022
8cdec6f
DQN with user set Q network
May 4, 2022
25b425d
newer version of rlberry
May 4, 2022
1daf92f
Merge branch 'new_main' into main
May 4, 2022
0a17c00
More details in DQN docstring about Qnet and test on changing default…
May 4, 2022
cdfe2c0
blacked last commit
May 4, 2022
4a25e6f
None is more clearly stated in DQN docstring in q_net_constructor
May 4, 2022
1984059
Better docstring for DQN: explains what is str for q_net_constructor …
May 5, 2022
57ce593
blacked last commit
May 5, 2022
e4065bb
Merge remote-tracking branch 'upstream/main' into main
May 11, 2022
bbcca1a
Merge remote-tracking branch 'upstream/main' into main
May 12, 2022
eaf6343
Merge remote-tracking branch 'upstream/main' into main
May 13, 2022
6955cfc
Merge remote-tracking branch 'upstream/main' into main
May 16, 2022
b19a54a
Merge remote-tracking branch 'upstream/main' into main
Jun 7, 2022
0e65761
Merge remote-tracking branch 'upstream/main' into main
Jun 28, 2022
abe710d
Merge remote-tracking branch 'upstream/main' into main
Aug 12, 2022
42d5271
Merge branch 'main' of github.com:AleShi94/rlberry into main
Aug 12, 2022
c067475
Merge remote-tracking branch 'upstream/main' into main
Sep 7, 2022
e407a2f
Merge remote-tracking branch 'upstream/main'
Oct 11, 2022
1654aca
Merge remote-tracking branch 'upstream/main'
Dec 15, 2022
564b05f
Merge remote-tracking branch 'upstream/main' into main
Dec 19, 2022
b72c7fd
Merge remote-tracking branch 'upstream/main'
Jul 24, 2023
4d442c2
model factory can take externally defined nn and load it from file + …
Jul 24, 2023
c278e6c
blacked
Jul 24, 2023
19b9bfd
more coverage
Jul 24, 2023
b99188d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2023
215bec5
blacked
Jul 24, 2023
5d1ec09
Merge remote-tracking branch 'upstream/main' into model_factory
Jul 24, 2023
925b040
Merge branch 'model_factory' of github.com:AleShi94/rlberry into mode…
Jul 24, 2023
411d3ca
flake 8 should be fine
Jul 24, 2023
114 changes: 112 additions & 2 deletions rlberry/agents/torch/tests/test_torch_training.py
@@ -1,7 +1,23 @@
import torch
from rlberry.agents.torch.utils.training import loss_function_factory, optimizer_factory

import os
from rlberry.agents.torch.utils.training import (
loss_function_factory,
optimizer_factory,
model_factory,
model_factory_from_env,
check_network,
)
from rlberry.envs.benchmarks.ball_exploration.ball2d import get_benchmark_env
from rlberry.agents.torch.utils.models import default_policy_net_fn
from rlberry.envs.finite import Chain
from rlberry.envs import gym_make
from rlberry.agents.torch.utils.models import (
default_policy_net_fn,
Net,
MultiLayerPerceptron,
)
from rlberry.agents.torch.dqn import DQNAgent


# loss_function_factory
assert isinstance(loss_function_factory("l2"), torch.nn.MSELoss)
@@ -11,6 +27,10 @@

# optimizer_factory
env = get_benchmark_env(level=1)

finite_env = Chain()

cont_act_env = gym_make("Pendulum-v1")
assert (
optimizer_factory(default_policy_net_fn(env).parameters(), "ADAM").defaults["lr"]
== 0.001
@@ -30,3 +50,93 @@
]
== 0.99
)


# test model_factory

obs_shape = env.observation_space.shape
n_act = env.action_space.n

test_net = Net(obs_size=obs_shape[0], hidden_size=10, n_actions=n_act)

test_net2 = MultiLayerPerceptron(in_size=obs_shape[0], layer_sizes=[10], out_size=1)


test_net3 = MultiLayerPerceptron(
in_size=obs_shape[0], layer_sizes=[10], out_size=n_act, is_policy=True
)

test_net4 = MultiLayerPerceptron(in_size=100, layer_sizes=[10], out_size=n_act)

test_net5 = MultiLayerPerceptron(
in_size=cont_act_env.observation_space.shape[0],
layer_sizes=[10],
out_size=cont_act_env.action_space.shape[0],
)


model_factory(net=test_net)
model_factory_from_env(env, net=test_net)
model_factory_from_env(env, net=test_net2, out_size=1)
model_factory_from_env(env, net=test_net3, is_policy=True)
model_factory_from_env(cont_act_env, net=test_net5)


# test loading pretrained nn
dqn_agent = DQNAgent(
env, q_net_constructor=model_factory_from_env, q_net_kwargs=dict(net=test_net)
)

dqn_agent.fit(50)

torch.save(dqn_agent._qnet_online, "test_dqn.pickle")


parameters_to_save = dqn_agent._qnet_online.state_dict()
torch.save(parameters_to_save, "test_dqn.pt")
torch.save((parameters_to_save, parameters_to_save), "test_dqn2.pt")

try:
model_factory(filename="test_dqn2.pt")
except Exception as err:
os.remove("test_dqn2.pt")
print(err, "Bad file was removed.")

try:
model_factory(type="dummy")
except Exception as err:
print(err)


# This check is expected to fail, since test_net does not match cont_act_env's spaces:
# try:
# check_network(cont_act_env, test_net)
# except Exception as err:
# print(err)


model_factory(filename="test_dqn.pickle")
model_factory(net=test_net, filename="test_dqn.pt")


dqn_agent = DQNAgent(
env,
q_net_constructor=model_factory_from_env,
q_net_kwargs=dict(filename="test_dqn.pickle"),
)

dqn_agent = DQNAgent(
env,
q_net_constructor=model_factory_from_env,
q_net_kwargs=dict(net=test_net, filename="test_dqn.pt"),
)

assert dqn_agent._qnet_online.state_dict().keys() == parameters_to_save.keys()

for k in parameters_to_save.keys():
assert (dqn_agent._qnet_online.state_dict()[k] == parameters_to_save[k]).all()

os.remove("test_dqn.pickle")
os.remove("test_dqn.pt")

print("done")
160 changes: 144 additions & 16 deletions rlberry/agents/torch/utils/training.py
@@ -27,21 +27,69 @@
raise ValueError("Unknown optimizer type: {}".format(optimizer_type))


def model_factory_from_env(env, **kwargs):
def model_factory_from_env(
env, type="MultiLayerPerceptron", net=None, filename=None, **net_kwargs
):
"""Returns a torch module after setting up input/output dimensions according to an env.

Parameters
----------
env: gym.Env
Environment
type: {"MultiLayerPerceptron",
"ConvolutionalNetwork",
"DuelingNetwork",
"Table"}, default = "MultiLayerPerceptron"
Type of neural network.
net: torch.nn.Module or None
If not None, this neural network is checked against the environment and returned. It can be used to pass a user-defined neural network.
filename: str or None
Path to a saved module or its 'state_dict'. If not None, the net or the checkpoint is loaded from this file.
**net_kwargs: dict
Parameters to be updated, used to call :func:`~rlberry.agents.torch.utils.training.model_factory`.
"""
kwargs = size_model_config(env, **kwargs)
return model_factory(**kwargs)

if filename is not None:
load_dict = load_from_file(filename)
if load_dict["model"] is not None:
net = load_dict["model"]
checkpoint = load_dict["checkpoint"]
else:
checkpoint = None

kwargs = size_model_config(env, type, **net_kwargs)

if net is not None:
check_network(env, net, **kwargs)

return model_factory(type, net, checkpoint=checkpoint, **kwargs)


def load_from_file(filename):
"""Load a module or a checkpoint.

Parameters
----------
filename: str
The path to a saved module or its 'state_dict'. It will load a net or a checkpoint.
"""
output_dict = dict(model=None, checkpoint=None)

loaded = torch.load(filename)
if isinstance(loaded, torch.nn.Module):
output_dict["model"] = loaded
elif isinstance(loaded, dict):
output_dict["checkpoint"] = loaded
else:
raise ValueError(
"Invalid 'load_from_file'. File is expected to store either an entire model or its 'state_dict'."
)
return output_dict


def model_factory(type="MultiLayerPerceptron", **kwargs) -> nn.Module:
def model_factory(
type="MultiLayerPerceptron", net=None, filename=None, checkpoint=None, **net_kwargs
) -> nn.Module:
"""Build a neural net of a given type.

Parameters
@@ -51,7 +99,13 @@
"DuelingNetwork",
"Table"}, default = "MultiLayerPerceptron"
Type of neural network.
**kwargs: dict
net: torch.nn.Module or None
If not None, return this neural network. It can be used to pass a user-defined neural network.
filename: str or None
Path to a saved module or its 'state_dict'. If not None, the net or the checkpoint is loaded from this file.
checkpoint: dict or None
If not None, it is treated as a 'state_dict' that is loaded into the constructed or provided network.
**net_kwargs: dict
Parameters that vary according to each neural net type, see

* :class:`~rlberry.agents.torch.utils.models.MultiLayerPerceptron`
@@ -69,19 +123,91 @@
Table,
)

if type == "MultiLayerPerceptron":
return MultiLayerPerceptron(**kwargs)
elif type == "DuelingNetwork":
return DuelingNetwork(**kwargs)
elif type == "ConvolutionalNetwork":
return ConvolutionalNetwork(**kwargs)
elif type == "Table":
return Table(**kwargs)
if filename is not None:
load_dict = load_from_file(filename)
if load_dict["model"] is not None:
return load_dict["model"]
else:
checkpoint = load_dict["checkpoint"]

if net is not None:
model = net
else:
raise ValueError("Unknown model type")
if type == "MultiLayerPerceptron":
model = MultiLayerPerceptron(**net_kwargs)
elif type == "DuelingNetwork":
model = DuelingNetwork(**net_kwargs)
elif type == "ConvolutionalNetwork":
model = ConvolutionalNetwork(**net_kwargs)
elif type == "Table":
model = Table(**net_kwargs)
else:
raise ValueError("Unknown model type")

if checkpoint is not None:
model.load_state_dict(checkpoint)

return model


def check_network(env, net, **model_config):
"""
Check that the neural network is compatible with the environment and the given model_config. Raises an error if the network is not compatible.

Parameters
----------
env : gym.Env
An environment.
net: torch.nn.Module
A neural network.
model_config : dict
Desired parameters.
"""

if isinstance(env.observation_space, spaces.Box):
obs_shape = env.observation_space.shape
else:
raise NotImplementedError

# elif isinstance(env.observation_space, spaces.Tuple):
# obs_shape = env.observation_space.spaces[0].shape
# elif isinstance(env.observation_space, spaces.Discrete):
# return model_config

if net is not None:
# check that it is compliant with environment
# input check
fake_input = torch.zeros(1, *obs_shape)
try:
output = net(fake_input)
except Exception as err:
print(

f"NN input is not compatible with the environment. Got an error {err=}, {type(err)=}"
)
raise

# output check
if "is_policy" in model_config:
is_policy = model_config["is_policy"]
if is_policy:
assert isinstance(
output, torch.distributions.distribution.Distribution
), "Policy should return distribution over actions"
else:
if "out_size" in model_config:
out_size = [model_config["out_size"]]
else:
if isinstance(env.action_space, spaces.Discrete):
out_size = [env.action_space.n]

elif isinstance(env.action_space, spaces.Tuple):
out_size = [env.action_space.spaces[0].n]

elif isinstance(env.action_space, spaces.Box):
out_size = env.action_space.shape
assert output.shape == (
1,
*out_size,
), f"Output should be of size {out_size}"


def size_model_config(env, **model_config):
def size_model_config(env, type=None, **model_config):
"""
Setup input/output dimensions for the configuration of
a model depending on the environment observation/action spaces.
@@ -90,6 +216,8 @@
----------
env : gym.Env
An environment.
type: str or None
Type of neural network for which the input/output configuration is built.
model_config : dict
Parameters to be updated, used to call :func:`~rlberry.agents.torch.utils.training.model_factory`.
If "out_size" is not given in model_config, assumes
@@ -105,7 +233,7 @@
return model_config

# Assume CHW observation space
if "type" in model_config and model_config["type"] == "ConvolutionalNetwork":
if type == "ConvolutionalNetwork":
if "transpose_obs" in model_config and not model_config["transpose_obs"]:
# Assume CHW observation space
if "in_channels" not in model_config:
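
Taken together, the diff above gives model_factory / model_factory_from_env three entry points: building a network from a registered type, wrapping a user-supplied module, and loading from a file that holds either a whole module or a bare 'state_dict'. The sketch below (not part of the diff) illustrates them; the environment choice ("CartPole-v1") and the file names are illustrative assumptions.

import os
import torch
from rlberry.envs import gym_make
from rlberry.agents.torch.utils.models import MultiLayerPerceptron
from rlberry.agents.torch.utils.training import model_factory, model_factory_from_env

env = gym_make("CartPole-v1")

# 1) Build by type: input/output sizes are inferred from the environment.
mlp = model_factory_from_env(env, type="MultiLayerPerceptron", layer_sizes=[32, 32])

# 2) Wrap a user-defined module: it is checked against the environment and returned.
custom_net = MultiLayerPerceptron(
    in_size=env.observation_space.shape[0],
    layer_sizes=[32],
    out_size=env.action_space.n,
)
same_net = model_factory_from_env(env, net=custom_net)

# 3) Load from a file holding either a whole module or a bare 'state_dict' (checkpoint).
torch.save(custom_net, "whole_model.pt")
torch.save(custom_net.state_dict(), "weights_only.pt")
loaded_module = model_factory(filename="whole_model.pt")
restored_net = model_factory(net=custom_net, filename="weights_only.pt")

os.remove("whole_model.pt")
os.remove("weights_only.pt")
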