initial commit DRQN implementation #257

Merged
merged 2 commits into from Mar 12, 2022
2 changes: 2 additions & 0 deletions hive/agents/__init__.py
@@ -1,6 +1,7 @@
from hive.agents import qnets
from hive.agents.agent import Agent
from hive.agents.dqn import DQNAgent
from hive.agents.drqn import DRQNAgent
from hive.agents.legal_moves_rainbow import LegalMovesRainbowAgent
from hive.agents.rainbow import RainbowDQNAgent
from hive.agents.random import RandomAgent
@@ -13,6 +14,7 @@
"LegalMovesRainbowAgent": LegalMovesRainbowAgent,
"RainbowDQNAgent": RainbowDQNAgent,
"RandomAgent": RandomAgent,
"DRQNAgent": DRQNAgent,
},
)
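
With the import and registration above, the new agent is re-exported from the hive.agents package root and addressable by the registry name "DRQNAgent". A quick smoke check (illustrative, not part of the diff) could be:

from hive.agents import DRQNAgent  # re-exported by this change

print(DRQNAgent.__mro__)  # DQNAgent should appear as the parent class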

2 changes: 1 addition & 1 deletion hive/agents/dqn.py
@@ -112,7 +112,7 @@ def __init__(
self._replay_buffer = replay_buffer
if self._replay_buffer is None:
self._replay_buffer = CircularReplayBuffer()
self._discount_rate = discount_rate ** n_step
self._discount_rate = discount_rate**n_step
self._grad_clip = grad_clip
self._reward_clip = reward_clip
self._target_net_soft_update = target_net_soft_update
288 changes: 288 additions & 0 deletions hive/agents/drqn.py
@@ -0,0 +1,288 @@
import copy
import os

import numpy as np
import torch

from hive.agents.agent import Agent
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.qnet_heads import DQNNetwork
from hive.agents.qnets.utils import (
InitializationFn,
calculate_output_dim,
create_init_weights_fn,
)
from hive.replays import BaseReplayBuffer, CircularReplayBuffer
from hive.utils.loggers import Logger, NullLogger
from hive.utils.schedule import (
LinearSchedule,
PeriodicSchedule,
Schedule,
SwitchSchedule,
)
from hive.agents.dqn import DQNAgent
from hive.utils.utils import LossFn, OptimizerFn, create_folder, seeder


class DRQNAgent(DQNAgent):
"""An agent implementing the DQN algorithm. Uses an epsilon greedy
exploration policy
"""

def __init__(
self,
representation_net: FunctionApproximator,
obs_dim,
act_dim: int,
id=0,
optimizer_fn: OptimizerFn = None,
loss_fn: LossFn = None,
init_fn: InitializationFn = None,
replay_buffer: BaseReplayBuffer = None,
discount_rate: float = 0.99,
n_step: int = 1,
grad_clip: float = None,
reward_clip: float = None,
update_period_schedule: Schedule = None,
target_net_soft_update: bool = False,
target_net_update_fraction: float = 0.05,
target_net_update_schedule: Schedule = None,
epsilon_schedule: Schedule = None,
test_epsilon: float = 0.001,
min_replay_history: int = 5000,
batch_size: int = 32,
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
):
"""
Args:
representation_net (FunctionApproximator): A network that outputs the
representations that will be used to compute Q-values (e.g.
everything except the final layer of the DQN).
obs_dim: The shape of the observations.
act_dim (int): The number of actions available to the agent.
id: Agent identifier.
optimizer_fn (OptimizerFn): A function that takes in a list of parameters
to optimize and returns the optimizer. If None, defaults to
:py:class:`~torch.optim.Adam`.
loss_fn (LossFn): Loss function used by the agent. If None, defaults to
:py:class:`~torch.nn.SmoothL1Loss`.
init_fn (InitializationFn): Initializes the weights of qnet using
create_init_weights_fn.
replay_buffer (BaseReplayBuffer): The replay buffer that the agent will
push observations to and sample from during learning. If None,
defaults to
:py:class:`~hive.replays.circular_replay.CircularReplayBuffer`.
discount_rate (float): A number between 0 and 1 specifying how much
future rewards are discounted by the agent.
n_step (int): The horizon used in n-step returns to compute TD(n) targets.
grad_clip (float): Gradients will be clipped to between
[-grad_clip, grad_clip].
reward_clip (float): Rewards will be clipped to between
[-reward_clip, reward_clip].
update_period_schedule (Schedule): Schedule determining how frequently
the agent's Q-network is updated.
target_net_soft_update (bool): Whether the target net parameters are
replaced by the qnet parameters completely or using a weighted
average of the target net parameters and the qnet parameters.
target_net_update_fraction (float): The weight given to the target
net parameters in a soft update.
target_net_update_schedule (Schedule): Schedule determining how frequently
the target net is updated.
epsilon_schedule (Schedule): Schedule determining the value of epsilon
through the course of training.
test_epsilon (float): epsilon (probability of choosing a random action)
to be used during the testing phase.
min_replay_history (int): How many observations to fill the replay buffer
with before starting to learn.
batch_size (int): The size of the batch sampled from the replay buffer
during learning.
device: Device on which all computations should be run.
logger (ScheduledLogger): Logger used to log agent's metrics.
log_frequency (int): How often to log the agent's metrics.
"""
super().__init__(
representation_net=representation_net,
obs_dim=obs_dim,
act_dim=act_dim,
id=id,
optimizer_fn=optimizer_fn,
loss_fn=loss_fn,
init_fn=init_fn,
replay_buffer=replay_buffer,
discount_rate=discount_rate,
n_step=n_step,
grad_clip=grad_clip,
reward_clip=reward_clip,
update_period_schedule=update_period_schedule,
target_net_soft_update=target_net_soft_update,
target_net_update_fraction=target_net_update_fraction,
target_net_update_schedule=target_net_update_schedule,
epsilon_schedule=epsilon_schedule,
test_epsilon=test_epsilon,
min_replay_history=min_replay_history,
batch_size=batch_size,
device=device,
logger=logger,
log_frequency=log_frequency,
)

def create_q_networks(self, representation_net):
"""Creates the Q-network and target Q-network.

Args:
representation_net: A network that outputs the representations that will
be used to compute Q-values (e.g. everything except the final layer
of the DQN).
"""
network = representation_net(self._obs_dim)
network_output_dim = np.prod(calculate_output_dim(network, self._obs_dim))
self._qnet = DQNNetwork(
network, network_output_dim, self._act_dim, use_rnn=True
).to(self._device)
self._qnet.apply(self._init_fn)
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)

def preprocess_update_info(self, update_info):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info.

Args:
update_info: Contains the information from the current timestep that the
agent should use to update itself.
"""
if self._reward_clip is not None:
update_info["reward"] = np.clip(
update_info["reward"], -self._reward_clip, self._reward_clip
)
preprocessed_update_info = {
"observation": update_info["observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": update_info["done"],
}
if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

return preprocessed_update_info

def preprocess_update_batch(self, batch):
"""Preprocess the batch sampled from the replay buffer.

Args:
batch: Batch sampled from the replay buffer for the current update.

Returns:
(tuple):
- (tuple) Inputs used to calculate current state values.
- (tuple) Inputs used to calculate next state values.
- Preprocessed batch.
"""
for key in batch:
batch[key] = torch.tensor(batch[key], device=self._device)
return (batch["observation"],), (batch["next_observation"],), batch

@torch.no_grad()
def act(self, observation):
"""Returns the action for the agent. If in training mode, follows an epsilon
greedy policy. Otherwise, returns the action with the highest Q-value.

Args:
observation: The current observation.
"""

# Determine and log the value of epsilon
if self._training:
if not self._learn_schedule.get_value():
epsilon = 1.0
else:
epsilon = self._epsilon_schedule.update()
if self._logger.update_step(self._timescale):
self._logger.log_scalar("epsilon", epsilon, self._timescale)
else:
epsilon = self._test_epsilon

# Sample action. With epsilon probability choose random action,
# otherwise select the action with the highest q-value.
observation = torch.tensor(
np.expand_dims(observation, axis=0), device=self._device
).float()
qvals = self._qnet(observation)
if self._rng.random() < epsilon:
action = self._rng.integers(self._act_dim)
else:
# Note: ties between max Q-values are not explicitly handled
action = torch.argmax(qvals).item()

if (
self._training
and self._logger.should_log(self._timescale)
and self._state["episode_start"]
):
self._logger.log_scalar("train_qval", torch.max(qvals), self._timescale)
self._state["episode_start"] = False
return action

def update(self, update_info):
"""
Updates the DRQN agent.

Args:
update_info: dictionary containing all the necessary information to
update the agent. Should contain a full transition, with keys for
"observation", "action", "reward", and "done".
"""
if update_info["done"]:
self._state["episode_start"] = True

if not self._training:
return

# Add the most recent transition to the replay buffer.
self._replay_buffer.add(**self.preprocess_update_info(update_info))

# Update the Q-network based on a sample batch from the replay buffer.
# If the replay buffer doesn't have enough samples yet, skip the update.
if (
self._learn_schedule.update()
and self._replay_buffer.size() > 0
and self._update_period_schedule.update()
):
batch = self._replay_buffer.sample(batch_size=self._batch_size)
(
current_state_inputs,
next_state_inputs,
batch,
) = self.preprocess_update_batch(batch)

# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals = self._qnet(*current_state_inputs)
actions = batch["action"].long()
pred_qvals = pred_qvals[torch.arange(pred_qvals.size(0)), actions]

# Compute 1-step Q targets
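# q_target = reward + discount_rate**n_step * max_a' Q_target(next_obs, a') * (1 - done);
# self._discount_rate already folds in the n_step exponent (see DQNAgent.__init__).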
next_qvals = self._target_qnet(*next_state_inputs)
next_qvals, _ = torch.max(next_qvals, dim=1)

q_targets = batch["reward"] + self._discount_rate * next_qvals * (
1 - batch["done"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)

loss.backward()
if self._grad_clip is not None:
torch.nn.utils.clip_grad_value_(
self._qnet.parameters(), self._grad_clip
)
self._optimizer.step()

# Update target network
if self._target_net_update_schedule.update():
self._update_target()
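
For reviewers who want to exercise the new agent end to end, a minimal interaction-loop sketch follows. It only uses the public methods added in this file (act and update); make_recurrent_representation_net is a hypothetical placeholder for however the FunctionApproximator is built (for example from the new ConvRNNNetwork), and the gym environment plus the agent.train() call are illustrative assumptions rather than part of this PR.

import gym  # illustrative environment; not part of this PR

from hive.agents import DRQNAgent

env = gym.make("CartPole-v1")
agent = DRQNAgent(
    representation_net=make_recurrent_representation_net,  # hypothetical FunctionApproximator factory
    obs_dim=env.observation_space.shape,
    act_dim=env.action_space.n,
    min_replay_history=500,
    batch_size=32,
)
agent.train()  # assumes the base Agent exposes a train()/eval() toggle

observation, done = env.reset(), False
while not done:
    action = agent.act(observation)
    next_observation, reward, done, _ = env.step(action)
    agent.update(
        {"observation": observation, "action": action, "reward": reward, "done": done}
    )
    observation = next_observation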
2 changes: 2 additions & 0 deletions hive/agents/qnets/__init__.py
@@ -3,12 +3,14 @@
from hive.agents.qnets.base import FunctionApproximator
from hive.agents.qnets.conv import ConvNetwork
from hive.agents.qnets.mlp import MLPNetwork
from hive.agents.qnets.rnn import ConvRNNNetwork

registry.register_all(
FunctionApproximator,
{
"MLPNetwork": FunctionApproximator(MLPNetwork),
"ConvNetwork": FunctionApproximator(ConvNetwork),
"ConvRNNNetwork": FunctionApproximator(ConvRNNNetwork),
"NatureAtariDQNModel": FunctionApproximator(NatureAtariDQNModel),
},
)
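
The ConvRNNNetwork added in hive/agents/qnets/rnn.py is registered here but its implementation is not shown in this diff. As background, the sketch below illustrates the general conv-encoder-plus-recurrent-core pattern a DRQN representation network follows; it is a conceptual example written for this review, not the code from rnn.py, and it assumes stacked 84x84 frames as input.

import torch
import torch.nn as nn

class ToyConvGRUEncoder(nn.Module):
    """Illustrative only: convolutional features fed through a GRU so the
    Q-head can condition on observation history, not just the current frame."""

    def __init__(self, in_channels=4, hidden_size=128):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Flatten(),
        )
        # 64 * 9 * 9 is the flattened conv output for 84x84 inputs
        self.gru = nn.GRU(input_size=64 * 9 * 9, hidden_size=hidden_size, batch_first=True)

    def forward(self, obs, hidden_state=None):
        # obs: (batch, seq_len, channels, height, width)
        batch, seq_len = obs.shape[:2]
        features = self.conv(obs.flatten(0, 1))       # (batch * seq_len, feat_dim)
        features = features.view(batch, seq_len, -1)  # regroup into sequences
        out, hidden_state = self.gru(features, hidden_state)
        return out, hidden_state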