Merge pull request #58 from Eclectic-Sheep/feature/save_args
Save CLI args
belerico authored Jul 17, 2023
2 parents 5c28ae9 + 3a0c4b7 commit 4239112
Showing 20 changed files with 60 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-tests.yaml
@@ -40,8 +40,8 @@ jobs:
       - name: Install packages
         run: |
-          pip install -U pip
-          pip install .[atari,test]
+          python -m pip install -U pip
+          python -m pip install .[atari,test]
       - name: Run tests
         run: |
15 changes: 13 additions & 2 deletions sheeprl/algos/args.py
@@ -1,5 +1,7 @@
-from dataclasses import dataclass
-from typing import Optional
+import json
+import os
+from dataclasses import asdict, dataclass
+from typing import Any, Optional
 
 from sheeprl.utils.parser import Arg
 
@@ -24,3 +26,12 @@ class StandardArgs:
         help="whether to move the buffer to the shared memory. "
         "Useful for pixel-based off-policy methods with large buffer size (>=1e6).",
     )
+    checkpoint_every: int = Arg(default=100, help="how often to make the checkpoint, -1 to deactivate the checkpoint")
+    checkpoint_path: Optional[str] = Arg(default=None, help="the path of the checkpoint from which you want to restart")
+
+    def __setattr__(self, __name: str, __value: Any) -> None:
+        super().__setattr__(__name, __value)
+        if __name == "log_dir":
+            file_name = os.path.join(__value, "args.json")
+            os.makedirs(__value, exist_ok=True)
+            json.dump(asdict(self), open(file_name, "w"))
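
The hook above is the heart of this PR: the first time a training script assigns log_dir on a StandardArgs (or subclass) instance, every dataclass field is dumped to <log_dir>/args.json. A minimal sketch of that behaviour, not part of the diff (the path is hypothetical and every field is assumed to have an Arg default):

    from sheeprl.algos.args import StandardArgs

    args = StandardArgs()            # fields fall back to their Arg(default=...) values
    args.log_dir = "logs/demo_run"   # hypothetical path: __setattr__ creates the
                                     # directory and writes logs/demo_run/args.json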
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v1/args.py
@@ -19,9 +19,7 @@ class DreamerV1Args(StandardArgs):
     learning_starts: int = Arg(default=int(5e3), help="timestep to start learning")
     gradient_steps: int = Arg(default=100, help="the number of gradient steps per each environment interaction")
     train_every: int = Arg(default=1000, help="the number of steps between one training and another")
-    checkpoint_every: int = Arg(default=-1, help="how often to make the checkpoint, -1 to deactivate the checkpoint")
     checkpoint_buffer: bool = Arg(default=False, help="whether or not to save the buffer during the checkpoint")
-    checkpoint_path: Optional[str] = Arg(default=None, help="the path of the checkpoint from which you want to restart")
 
     # Agent settings
     world_lr: float = Arg(default=6e-4, help="the learning rate of the optimizer of the world model")
3 changes: 3 additions & 0 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -387,6 +387,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
2 changes: 0 additions & 2 deletions sheeprl/algos/dreamer_v2/args.py
@@ -22,9 +22,7 @@ class DreamerV2Args(StandardArgs):
     pretrain_steps: int = Arg(default=100, help="the number of pretrain steps")
     gradient_steps: int = Arg(default=1, help="the number of gradient steps per each environment interaction")
     train_every: int = Arg(default=5, help="the number of steps between one training and another")
-    checkpoint_every: int = Arg(default=-1, help="how often to make the checkpoint, -1 to deactivate the checkpoint")
     checkpoint_buffer: bool = Arg(default=False, help="whether or not to save the buffer during the checkpoint")
-    checkpoint_path: Optional[str] = Arg(default=None, help="the path of the checkpoint from which you want to restart")
     buffer_type: str = Arg(
         default="sequential",
         help="which buffer to use: `sequential` or `episode`. The `episode` "
6 changes: 6 additions & 0 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -416,12 +416,18 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
         log_dir = data[0]
         os.makedirs(log_dir, exist_ok=True)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
 
     env: gym.Env = make_env(
         args.env_id,
         args.seed + rank * args.num_envs,
3 changes: 3 additions & 0 deletions sheeprl/algos/droq/droq.py
@@ -154,6 +154,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
3 changes: 3 additions & 0 deletions sheeprl/algos/p2e/p2e_dv1/p2e_dv1.py
@@ -400,6 +400,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
1 change: 0 additions & 1 deletion sheeprl/algos/ppo/args.py
@@ -32,6 +32,5 @@ class PPOArgs(StandardArgs):
     anneal_ent_coef: bool = Arg(default=False, help="whether to linearly anneal the entropy coefficient to zero")
     vf_coef: float = Arg(default=1.0, help="coefficient of the value function")
     max_grad_norm: float = Arg(default=0.0, help="the maximum norm for the gradient clipping")
-    checkpoint_every: int = Arg(default=-1, help="how often to make the checkpoint, -1 to deactivate the checkpoint")
     actor_hidden_size: int = Arg(default=64, help="the dimension of the hidden sizes of the actor network")
     critic_hidden_size: int = Arg(default=64, help="the dimension of the hidden sizes of the critic network")
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo/ppo.py
@@ -139,6 +139,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo/ppo_decoupled.py
@@ -60,6 +60,9 @@ def player(args: PPOArgs, world_collective: TorchCollective, player_trainer_coll
     logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
     logger.log_hyperparams(asdict(args))
 
+    # Save args as dict automatically
+    args.log_dir = logger.log_dir
+
     # Initialize Fabric object
     fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()])
     if not _is_using_cli():
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo_continuous/ppo_continuous.py
@@ -133,6 +133,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo_pixel/ppo_atari.py
@@ -95,6 +95,9 @@ def player(args: PPOAtariArgs, world_collective: TorchCollective, player_trainer
     logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
     logger.log_hyperparams(asdict(args))
 
+    # Save args as dict automatically
+    args.log_dir = logger.log_dir
+
     # Initialize Fabric object
     fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()])
     if not _is_using_cli():
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo_pixel/ppo_pixel_continuous.py
@@ -102,6 +102,9 @@ def player(args: PPOPixelContinuousArgs, world_collective: TorchCollective, play
     logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
     logger.log_hyperparams(asdict(args))
 
+    # Save args as dict automatically
+    args.log_dir = logger.log_dir
+
     # Initialize Fabric object
     fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()])
     if not _is_using_cli():
3 changes: 3 additions & 0 deletions sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -153,6 +153,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
1 change: 0 additions & 1 deletion sheeprl/algos/sac/args.py
@@ -23,7 +23,6 @@ class SACArgs(StandardArgs):
     policy_lr: float = Arg(default=3e-4, help="the learning rate of the entropy coefficient optimizer")
     target_network_frequency: int = Arg(default=1, help="the frequency of updates for the target nerworks")
     gradient_steps: int = Arg(default=1, help="the number of gradient steps per each environment interaction")
-    checkpoint_every: int = Arg(default=-1, help="how often to make the checkpoint, -1 to deactivate the checkpoint")
     checkpoint_buffer: bool = Arg(default=False, help="whether or not to save the buffer during the checkpoint")
     sample_next_obs: bool = Arg(
         default=False, help="whether or not to sample the next observations from the gathered observations"
3 changes: 3 additions & 0 deletions sheeprl/algos/sac/sac.py
@@ -120,6 +120,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
3 changes: 3 additions & 0 deletions sheeprl/algos/sac/sac_decoupled.py
@@ -44,6 +44,9 @@ def player(args: SACArgs, world_collective: TorchCollective, player_trainer_coll
     logger = TensorBoardLogger(root_dir=root_dir, name=run_name)
     logger.log_hyperparams(asdict(args))
 
+    # Save args as dict automatically
+    args.log_dir = logger.log_dir
+
     # Initialize Fabric
     fabric = Fabric(loggers=logger, callbacks=[CheckpointCallback()])
     if not _is_using_cli():
3 changes: 3 additions & 0 deletions sheeprl/algos/sac_pixel/sac_pixel_continuous.py
@@ -171,6 +171,9 @@ def main():
         fabric.logger.log_hyperparams(asdict(args))
         if fabric.world_size > 1:
             world_collective.broadcast_object_list([log_dir], src=0)
+
+        # Save args as dict automatically
+        args.log_dir = log_dir
     else:
         data = [None]
         world_collective.broadcast_object_list(data, src=0)
3 changes: 3 additions & 0 deletions tests/test_algos/test_algos.py
@@ -55,6 +55,9 @@ def check_checkpoint(ckpt_path: str, target_keys: set, checkpoint_buffer: bool =
     # if checkpoint_buffer is false, then "rb" cannot be in the checkpoint keys
     assert checkpoint_buffer or "rb" not in ckpt_keys
 
+    # check args are saved
+    assert os.path.exists(os.path.join(os.path.dirname(ckpt_path), "args.json"))
+
 
 @pytest.mark.timeout(60)
 @pytest.mark.parametrize("checkpoint_buffer", [True, False])
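
The new assertion above expects args.json in the same directory as the checkpoint file, which is where the StandardArgs hook wrote it. A hedged sketch of reading the dump back, for example to inspect the configuration of a finished run (the path is hypothetical, not part of the diff):

    import json

    with open("logs/demo_run/args.json") as f:
        saved_args = json.load(f)              # plain dict: field name -> value
    print(saved_args["checkpoint_every"])      # 100 unless overridden on the CLI (new StandardArgs default)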
