Added array-like min and max actions #29

Open · wants to merge 1 commit into base: master
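For context on the diffs below: the actors now map tanh's (-1, 1) output onto an arbitrary per-dimension box [min_action, max_action] instead of the symmetric [-max_action, max_action]. A minimal standalone sketch of that rescaling, with made-up bounds that are not taken from the PR:

import torch

# Hypothetical per-dimension bounds for a 3-dimensional action space.
min_action = torch.FloatTensor([-1.0, 0.0, -2.0])
max_action = torch.FloatTensor([1.0, 0.5, 2.0])

raw = torch.randn(4, 3)   # pre-activation actor output (batch of 4)
t = torch.tanh(raw)       # squashed into (-1, 1)

# Same affine map used in the modified forward(): shift/scale (-1, 1) onto [min, max].
action = (max_action - min_action) * ((t + 1) / 2) + min_action

assert (action >= min_action).all() and (action <= max_action).all()
print(action)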
13 changes: 7 additions & 6 deletions DDPG.py
@@ -13,20 +13,21 @@


 class Actor(nn.Module):
-    def __init__(self, state_dim, action_dim, max_action):
+    def __init__(self, state_dim, action_dim, min_action, max_action):
         super(Actor, self).__init__()

         self.l1 = nn.Linear(state_dim, 400)
         self.l2 = nn.Linear(400, 300)
         self.l3 = nn.Linear(300, action_dim)

-        self.max_action = max_action
+        self.min_action = torch.FloatTensor(min_action)
+        self.max_action = torch.FloatTensor(max_action)


     def forward(self, state):
         a = F.relu(self.l1(state))
         a = F.relu(self.l2(a))
-        return self.max_action * torch.tanh(self.l3(a))
+        return (self.max_action - self.min_action) * ((torch.tanh(self.l3(a)) + 1) / 2) + self.min_action


 class Critic(nn.Module):
@@ -45,8 +46,8 @@ def forward(self, state, action):


 class DDPG(object):
-    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.001):
-        self.actor = Actor(state_dim, action_dim, max_action).to(device)
+    def __init__(self, state_dim, action_dim, min_action, max_action, discount=0.99, tau=0.001):
+        self.actor = Actor(state_dim, action_dim, min_action, max_action).to(device)
         self.actor_target = copy.deepcopy(self.actor)
         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)

@@ -114,4 +115,4 @@ def load(self, filename):
         self.actor.load_state_dict(torch.load(filename + "_actor"))
         self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
         self.actor_target = copy.deepcopy(self.actor)
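If this change were merged, the DDPG agent would be constructed with array-like bounds rather than a single scalar. A hypothetical usage sketch, assuming the PR branch is importable; the dimensions and bounds are illustrative, not taken from the PR:

import numpy as np
import DDPG  # the modified module from this PR

state_dim, action_dim = 3, 1
min_action = np.array([-2.0])   # e.g. env.action_space.low
max_action = np.array([2.0])    # e.g. env.action_space.high

policy = DDPG.DDPG(state_dim, action_dim, min_action, max_action, discount=0.99, tau=0.001)

state = np.zeros(state_dim, dtype=np.float32)
action = policy.select_action(state)   # should lie inside [min_action, max_action]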

13 changes: 7 additions & 6 deletions OurDDPG.py
@@ -12,20 +12,21 @@


 class Actor(nn.Module):
-    def __init__(self, state_dim, action_dim, max_action):
+    def __init__(self, state_dim, action_dim, min_action, max_action):
         super(Actor, self).__init__()

         self.l1 = nn.Linear(state_dim, 400)
         self.l2 = nn.Linear(400, 300)
         self.l3 = nn.Linear(300, action_dim)

-        self.max_action = max_action
+        self.min_action = torch.FloatTensor(min_action)
+        self.max_action = torch.FloatTensor(max_action)


     def forward(self, state):
         a = F.relu(self.l1(state))
         a = F.relu(self.l2(a))
-        return self.max_action * torch.tanh(self.l3(a))
+        return (self.max_action - self.min_action) * ((torch.tanh(self.l3(a)) + 1) / 2) + self.min_action


 class Critic(nn.Module):
@@ -44,8 +45,8 @@ def forward(self, state, action):


 class DDPG(object):
-    def __init__(self, state_dim, action_dim, max_action, discount=0.99, tau=0.005):
-        self.actor = Actor(state_dim, action_dim, max_action).to(device)
+    def __init__(self, state_dim, action_dim, min_action, max_action, discount=0.99, tau=0.005):
+        self.actor = Actor(state_dim, action_dim, min_action, max_action).to(device)
         self.actor_target = copy.deepcopy(self.actor)
         self.actor_optimizer = torch.optim.Adam(self.actor.parameters())

@@ -113,4 +114,4 @@ def load(self, filename):
         self.actor.load_state_dict(torch.load(filename + "_actor"))
         self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
         self.actor_target = copy.deepcopy(self.actor)
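A note on the rescaled forward() in these actors: when the bounds are symmetric (min_action == -max_action), the new expression reduces to the original max_action * tanh(x), since (max - (-max)) * ((t + 1) / 2) + (-max) = max * t. A quick numerical check of that identity, with arbitrary values:

import torch

max_action = torch.FloatTensor([2.0, 0.5])
min_action = -max_action                      # symmetric box

x = torch.randn(8, 2)
t = torch.tanh(x)

old = max_action * t
new = (max_action - min_action) * ((t + 1) / 2) + min_action

assert torch.allclose(old, new, atol=1e-6)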


31 changes: 17 additions & 14 deletions TD3.py
@@ -12,20 +12,21 @@


 class Actor(nn.Module):
-    def __init__(self, state_dim, action_dim, max_action):
+    def __init__(self, state_dim, action_dim, min_action, max_action):
         super(Actor, self).__init__()

         self.l1 = nn.Linear(state_dim, 256)
         self.l2 = nn.Linear(256, 256)
         self.l3 = nn.Linear(256, action_dim)

-        self.max_action = max_action
+        self.min_action = torch.FloatTensor(min_action)
+        self.max_action = torch.FloatTensor(max_action)


     def forward(self, state):
         a = F.relu(self.l1(state))
         a = F.relu(self.l2(a))
-        return self.max_action * torch.tanh(self.l3(a))
+        return (self.max_action - self.min_action) * ((torch.tanh(self.l3(a)) + 1) / 2) + self.min_action

Review comment on the line above:

Hi, may I ask why you changed this line? I have tested your code and, compared with the original, the performance becomes worse and the output actions look very weird.



class Critic(nn.Module):
@@ -70,6 +71,7 @@ def __init__(
         self,
         state_dim,
         action_dim,
+        min_action,
         max_action,
         discount=0.99,
         tau=0.005,
@@ -78,19 +80,20 @@ def __init__(
         policy_freq=2
     ):

-        self.actor = Actor(state_dim, action_dim, max_action).to(device)
+        self.actor = Actor(state_dim, action_dim, min_action, max_action).to(device)
         self.actor_target = copy.deepcopy(self.actor)
         self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4)

         self.critic = Critic(state_dim, action_dim).to(device)
         self.critic_target = copy.deepcopy(self.critic)
         self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4)

-        self.max_action = max_action
+        self.min_action = torch.FloatTensor(min_action)
+        self.max_action = torch.FloatTensor(max_action)
         self.discount = discount
         self.tau = tau
-        self.policy_noise = policy_noise
-        self.noise_clip = noise_clip
+        self.policy_noise = torch.FloatTensor(policy_noise)
+        self.noise_clip = torch.FloatTensor(noise_clip)
         self.policy_freq = policy_freq

         self.total_it = 0
@@ -109,13 +112,13 @@ def train(self, replay_buffer, batch_size=100):

         with torch.no_grad():
             # Select action according to policy and add clipped noise
-            noise = (
-                torch.randn_like(action) * self.policy_noise
-            ).clamp(-self.noise_clip, self.noise_clip)
+            noise = torch.max(
+                torch.min(torch.randn_like(action) * self.policy_noise, self.noise_clip
+            ), -self.noise_clip)

-            next_action = (
-                self.actor_target(next_state) + noise
-            ).clamp(-self.max_action, self.max_action)
+            next_action = torch.max(
+                torch.min(self.actor_target(next_state) + noise, self.max_action
+            ), self.min_action)

             # Compute the target Q value
             target_Q1, target_Q2 = self.critic_target(next_state, next_action)
@@ -168,4 +171,4 @@ def load(self, filename):
         self.actor.load_state_dict(torch.load(filename + "_actor"))
         self.actor_optimizer.load_state_dict(torch.load(filename + "_actor_optimizer"))
         self.actor_target = copy.deepcopy(self.actor)
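On the torch.max/torch.min nesting in train() above: it clamps each action dimension against its own bound, whereas Tensor.clamp historically accepted only Python scalars for min/max. A small sketch of the element-wise equivalent, with illustrative bounds:

import torch

low = torch.FloatTensor([-1.0, 0.0])
high = torch.FloatTensor([1.0, 0.5])

x = torch.randn(5, 2) * 3

# Element-wise clamp via nested min/max, as in the modified TD3.train():
clamped = torch.max(torch.min(x, high), low)

assert (clamped >= low).all() and (clamped <= high).all()

# On PyTorch >= 1.9 (an assumption about the installed version), tensor bounds
# are also accepted directly:
# clamped_alt = torch.clamp(x, min=low, max=high)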


12 changes: 7 additions & 5 deletions main.py
@@ -72,11 +72,13 @@ def eval_policy(policy, env_name, seed, eval_episodes=10):

     state_dim = env.observation_space.shape[0]
     action_dim = env.action_space.shape[0]
-    max_action = float(env.action_space.high[0])
+    min_action = env.action_space.low
+    max_action = env.action_space.high

     kwargs = {
         "state_dim": state_dim,
         "action_dim": action_dim,
+        "min_action": min_action,
         "max_action": max_action,
         "discount": args.discount,
         "tau": args.tau,
@@ -85,8 +87,8 @@ def eval_policy(policy, env_name, seed, eval_episodes=10):
     # Initialize policy
     if args.policy == "TD3":
         # Target policy smoothing is scaled wrt the action scale
-        kwargs["policy_noise"] = args.policy_noise * max_action
-        kwargs["noise_clip"] = args.noise_clip * max_action
+        kwargs["policy_noise"] = args.policy_noise * (max_action - min_action) / 2
+        kwargs["noise_clip"] = args.noise_clip * (max_action - min_action) / 2
         kwargs["policy_freq"] = args.policy_freq
         policy = TD3.TD3(**kwargs)
     elif args.policy == "OurDDPG":
@@ -118,8 +120,8 @@ def eval_policy(policy, env_name, seed, eval_episodes=10):
         else:
             action = (
                 policy.select_action(np.array(state))
-                + np.random.normal(0, max_action * args.expl_noise, size=action_dim)
-            ).clip(-max_action, max_action)
+                + np.random.normal(0, (max_action - min_action) * args.expl_noise / 2, size=action_dim)
+            ).clip(min_action, max_action)

             # Perform action
             next_state, reward, done, _ = env.step(action)
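With per-dimension bounds, the exploration noise above is scaled by half the width of each dimension's range, (high - low) / 2, which matches the old max_action scaling when the range is symmetric around zero. A small NumPy sketch with illustrative bounds, not taken from the PR:

import numpy as np

expl_noise = 0.1
low = np.array([-1.0, 0.0, -2.0])     # e.g. env.action_space.low
high = np.array([1.0, 0.5, 2.0])      # e.g. env.action_space.high

raw_action = np.array([0.9, 0.4, -1.5])   # stand-in for policy.select_action(state)

# Per-dimension Gaussian noise, std proportional to half the action range.
noise = np.random.normal(0, (high - low) * expl_noise / 2, size=low.shape)

# Keep the noisy action inside the box.
action = (raw_action + noise).clip(low, high)
print(action)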