Add support for SB3 callbacks in adversarial training #786

Open · wants to merge 4 commits into base: master
36 changes: 24 additions & 12 deletions src/imitation/algorithms/adversarial/common.py
@@ -2,7 +2,7 @@
import abc
import dataclasses
import logging
from typing import Callable, Iterable, Iterator, Mapping, Optional, Type, overload
from typing import Iterable, Iterator, List, Mapping, Optional, Type, overload

import numpy as np
import torch as th
@@ -15,6 +15,8 @@
policies,
vec_env,
)
from stable_baselines3.common.callbacks import BaseCallback, ConvertCallback
from stable_baselines3.common.type_aliases import MaybeCallback
from stable_baselines3.sac import policies as sac_policies
from torch.nn import functional as F

@@ -392,6 +394,7 @@ def train_gen(
self,
total_timesteps: Optional[int] = None,
learn_kwargs: Optional[Mapping] = None,
callback: MaybeCallback = None,
) -> None:
"""Trains the generator to maximize the discriminator loss.

@@ -404,17 +407,30 @@
`self.gen_train_timesteps`.
learn_kwargs: kwargs for the Stable Baselines `RLModel.learn()`
method.
callback: additional callback(s) passed to the generator's `learn` method.
"""
if total_timesteps is None:
total_timesteps = self.gen_train_timesteps
if learn_kwargs is None:
learn_kwargs = {}

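# Merge `self.gen_callback` with the user-provided callback(s); plain
# callables are wrapped in `ConvertCallback` before being passed to `learn()`.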
callbacks: List[BaseCallback] = []

if self.gen_callback:
callbacks.append(self.gen_callback)

if isinstance(callback, list):
callbacks.extend(callback)
elif isinstance(callback, BaseCallback):
callbacks.append(callback)
elif callback is not None:
callbacks.append(ConvertCallback(callback))

with self.logger.accumulate_means("gen"):
self.gen_algo.learn(
total_timesteps=total_timesteps,
reset_num_timesteps=False,
callback=self.gen_callback,
callback=callbacks,
**learn_kwargs,
)
self._global_step += 1
@@ -427,37 +443,33 @@ def train_gen(
def train(
self,
total_timesteps: int,
callback: Optional[Callable[[int], None]] = None,
callback: MaybeCallback = None,
) -> None:
"""Alternates between training the generator and discriminator.

Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
a call to `train_disc`, and finally a call to `callback(round)`.
Every "round" consists of a call to
`train_gen(self.gen_train_timesteps, callback)`, then a call to `train_disc`.

Training ends once an additional "round" would cause the number of transitions
sampled from the environment to exceed `total_timesteps`.

Args:
total_timesteps: An upper bound on the number of transitions to sample
from the environment during training.
callback: A function called at the end of every round which takes in a
single argument, the round number. Round numbers are in
`range(total_timesteps // self.gen_train_timesteps)`.
callback: callback(s) passed to the generator's `learn` method.
"""
n_rounds = total_timesteps // self.gen_train_timesteps
assert n_rounds >= 1, (
"No updates (need at least "
f"{self.gen_train_timesteps} timesteps, have only "
f"total_timesteps={total_timesteps})!"
)
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps)
for _r in tqdm.tqdm(range(0, n_rounds), desc="round"):
self.train_gen(self.gen_train_timesteps, callback=callback)
for _ in range(self.n_disc_updates_per_round):
with networks.training(self.reward_train):
# switch to training mode (affects dropout, normalization)
self.train_disc()
if callback:
callback(r)
self.logger.dump(self._global_step)

@overload
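With this change, `train` and `train_gen` accept any SB3 `MaybeCallback`. A minimal usage sketch, not part of the diff: `trainer` stands for an `AdversarialTrainer` (e.g. GAIL) built elsewhere, and `eval_env` and the `EvalCallback` arguments are illustrative only.

from stable_baselines3.common.callbacks import EvalCallback

def on_step(locals_, globals_) -> bool:
    # A plain callable is wrapped in `ConvertCallback` and invoked on every
    # environment step inside `learn()`; returning False stops that call early.
    return True

# A single BaseCallback, a list of callbacks, or a plain function all work.
trainer.train(total_timesteps=10_000, callback=EvalCallback(eval_env, eval_freq=1_000))
trainer.train(total_timesteps=10_000, callback=on_step)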
31 changes: 27 additions & 4 deletions src/imitation/scripts/train_adversarial.py
@@ -8,6 +8,7 @@
import sacred.commands
import torch as th
from sacred.observers import FileStorageObserver
from stable_baselines3.common.callbacks import BaseCallback

from imitation.algorithms.adversarial import airl as airl_algo
from imitation.algorithms.adversarial import common
@@ -22,6 +23,31 @@
logger = logging.getLogger("imitation.scripts.train_adversarial")


class CheckpointCallback(BaseCallback):
"""A callback for calling `save` at regular intervals."""

def __init__(
self,
trainer: common.AdversarialTrainer,
log_dir: pathlib.Path,
interval: int,
):
"""Creates new Checkpoint callback."""
super().__init__()
self.trainer = trainer
self.log_dir = log_dir
self.interval = interval
self.round_num = 0

def _on_step(self) -> bool:
return True

def _on_training_end(self) -> None:
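# Each adversarial round triggers exactly one `learn()` call, so this
# hook fires once per round.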
self.round_num += 1
if self.interval > 0 and self.round_num % self.interval == 0:
save(self.trainer, self.log_dir / "checkpoints" / f"{self.round_num:05d}")


def save(trainer: common.AdversarialTrainer, save_path: pathlib.Path):
"""Save discriminator and generator."""
# We implement this here and not in Trainer since we do not want to actually
@@ -153,10 +179,7 @@ def train_adversarial(
**algorithm_kwargs,
)

def callback(round_num: int, /) -> None:
if checkpoint_interval > 0 and round_num % checkpoint_interval == 0:
save(trainer, log_dir / "checkpoints" / f"{round_num:05d}")

callback = CheckpointCallback(trainer, log_dir, checkpoint_interval)
trainer.train(total_timesteps, callback)
imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train)

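Downstream code that relied on the old per-round `callback(round_num)` hook can recover it with an SB3 callback, because each adversarial round triggers exactly one `learn()` call. A sketch under that assumption; the `RoundCallback` name and `on_round` argument are illustrative, not part of this PR.

from stable_baselines3.common.callbacks import BaseCallback

class RoundCallback(BaseCallback):
    """Invokes `on_round(round_num)` at the end of every adversarial round."""

    def __init__(self, on_round):
        super().__init__()
        self.on_round = on_round
        self.round_num = 0

    def _on_step(self) -> bool:
        return True

    def _on_training_end(self) -> None:
        # `learn()` finishes exactly once per round, so this runs once per round.
        self.on_round(self.round_num)
        self.round_num += 1

# Hypothetical usage: `trainer` and `save_round` are defined elsewhere.
# trainer.train(total_timesteps, callback=RoundCallback(save_round))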
36 changes: 36 additions & 0 deletions tests/algorithms/test_adversarial.py
@@ -10,6 +10,7 @@
import stable_baselines3
import torch as th
from stable_baselines3.common import policies
from stable_baselines3.common.callbacks import BaseCallback
from torch.utils import data as th_data

from imitation.algorithms.adversarial import airl, common, gail
@@ -464,3 +465,38 @@ def test_regression_gail_with_sac(
reward_net=reward_net,
)
gail_trainer.train(8)


def test_gen_callback(trainer: common.AdversarialTrainer):
def make_fn_callback(calls, key):
def cb(_a, _b):
calls[key] += 1

return cb

class SB3Callback(BaseCallback):
def __init__(self, calls, key):
super().__init__()
self.calls = calls
self.key = key

def _on_step(self):
self.calls[self.key] += 1
return True

n_steps = trainer.gen_train_timesteps * 2
calls = {"fn": 0, "sb3": 0, "list.0": 0, "list.1": 0}

trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
trainer.train(
n_steps,
callback=[SB3Callback(calls, "list.0"), SB3Callback(calls, "list.1")],
)

# Env steps for off-policy algos (e.g. DQN) may exceed `total_timesteps`,
# so we check that the callback was called *at least* that many times.
assert calls["fn"] >= n_steps
assert calls["sb3"] >= n_steps
assert calls["list.0"] >= n_steps
assert calls["list.1"] >= n_steps