Make Unity Environment support heterogeneous observations and multiple agents.
hyerra committed Jun 25, 2023
1 parent 8f46429 commit e2466bd
Showing 1 changed file with 149 additions and 125 deletions.
274 changes: 149 additions & 125 deletions torchrl/envs/libs/unity.py
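
The central change is to give every spec a leading agent dimension by stacking per-agent specs, and to add a valid_mask entry that marks which agents actually reported data on a given step. The snippet below is only an illustrative sketch, not part of the commit: the agent count and shapes are invented, and it assumes that torch.stack over torchrl specs builds a stacked, per-agent spec in the way the diff relies on.

import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

num_agents = 3  # hypothetical agent count

# One observation spec per agent. They share a shape here; the diff combines
# specs of varying shapes per behavior with the same torch.stack call.
per_agent_obs_specs = [
    UnboundedContinuousTensorSpec(shape=torch.Size([8]), dtype=torch.float32)
    for _ in range(num_agents)
]

# Stacking along dim 0 yields a single spec with the agent dimension first,
# which is what the new _make_specs stores as observation_spec.
observation_spec = CompositeSpec(
    observation=torch.stack(per_agent_obs_specs, dim=0)
)

# A zero-filled sample carries the agent dimension on every entry.
sample = observation_spec.zero()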
@@ -36,7 +36,7 @@ def _unity_to_torchrl_spec_transform(spec, dtype=None, device="cpu"):
if not len(shape):
shape = torch.Size([1])
dtype = numpy_to_torch_dtype_dict[dtype]
return UnboundedContinuousTensorSpec(shape, device=device, dtype=dtype)
return UnboundedContinuousTensorSpec(shape=shape, device=device, dtype=dtype)
elif isinstance(spec, ActionSpec):
if spec.continuous_size == len(spec.discrete_branches) == 0:
raise ValueError("No available actions")
@@ -51,7 +51,7 @@ def _unity_to_torchrl_spec_transform(spec, dtype=None, device="cpu"):
spec.discrete_branches, shape=[spec.discrete_size], device=device
)
# FIXME: Need tuple support as action masks are 2D arrays
# action_mask_spec = MultiDiscreteTensorSpec(spec.discrete_branches, dtype=torch.bool, device=device)
# noqa: E501 action_mask_spec = MultiDiscreteTensorSpec(spec.discrete_branches, dtype=torch.bool, device=device)

if spec.continuous_size > 0:
dtype = numpy_to_torch_dtype_dict[dtype]
@@ -91,7 +91,14 @@ def __init__(self, env=None, **kwargs):
super().__init__(**kwargs)

def _init_env(self):
self._behavior_names = []
pass

def _compute_num_agents(self, env):
num_agents = 0
for behavior_name in env.behavior_specs.keys():
decision_steps, terminal_steps = env.get_steps(behavior_name)
num_agents += len(decision_steps) + len(terminal_steps)
return num_agents
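
As a side note (not part of the diff), the agent count is just the sum of decision steps (agents requesting an action) and terminal steps (agents that just finished) over all behaviors. A self-contained sketch of the same counting logic, with stand-in objects whose names are all hypothetical:

# Stand-ins for mlagents_envs' DecisionSteps/TerminalSteps: for the count,
# all that matters is that each exposes len() per behavior.
class FakeSteps:
    def __init__(self, agent_ids):
        self.agent_id = list(agent_ids)

    def __len__(self):
        return len(self.agent_id)


class FakeUnityEnv:
    behavior_specs = {"Striker?team=0": None, "Goalie?team=0": None}

    def get_steps(self, behavior_name):
        # Two strikers still deciding; one goalie deciding, one terminated.
        if behavior_name.startswith("Striker"):
            return FakeSteps([0, 1]), FakeSteps([])
        return FakeSteps([2]), FakeSteps([3])


def compute_num_agents(env):
    num_agents = 0
    for behavior_name in env.behavior_specs.keys():
        decision_steps, terminal_steps = env.get_steps(behavior_name)
        num_agents += len(decision_steps) + len(terminal_steps)
    return num_agents


assert compute_num_agents(FakeUnityEnv()) == 4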

def _set_seed(self, seed: int | None):
warn(
@@ -107,52 +114,75 @@ def _build_env(self, env: BaseEnv):
if not env.behavior_specs:
# Take a single step so that the brain information will be sent over
env.step()
self._behavior_names = list(env.behavior_specs.keys())
self.num_agents = self._compute_num_agents(env)
self._agent_id_to_behavior = {}
return env

def _make_specs(self, env: BaseEnv) -> None:
# TODO: Behavior specs are immutable, but new ones
# can appear when behaviors are created in the environment.
# We need to account for behavior specs that are added
# throughout the environment's lifecycle.

# IMPORTANT: This assumes that all agents have the same
# observations and actions. To change this, we would need
# a way to allow different specs per agent.
#
# A different `Parallel` version of this environment could be
# made where the number of agents is fixed, and then you stack
# all of the observations together. This design would allow
# different observations and actions, but would require
# a fixed agent count. The difficulty with implementing a
# `Parallel` version, though, is that not all agents will request
# a decision at every step, so the spec would have to change
# depending on which agents request a decision.

first_behavior_name = next(iter(env.behavior_specs.keys()))
behavior_unity_spec = env.behavior_specs[first_behavior_name]
observation_specs = [
_unity_to_torchrl_spec_transform(
spec, dtype=np.dtype("float32"), device=self.device
)
for spec in behavior_unity_spec.observation_specs
]
behavior_id_spec = UnboundedDiscreteTensorSpec(1, device=self.device)
agent_id_spec = UnboundedDiscreteTensorSpec(1, device=self.device)
# FIXME: Need Tuple support here so we can support observations of varying dimensions.
# Thus, for now we use only the first observation.
observation_specs = [None] * self.num_agents
behavior_id_specs = [None] * self.num_agents
agent_id_specs = [None] * self.num_agents
action_specs = [None] * self.num_agents
action_mask_specs = [None] * self.num_agents
reward_specs = [None] * self.num_agents
done_specs = [None] * self.num_agents
valid_mask_specs = [None] * self.num_agents

for behavior_name, behavior_unity_spec in env.behavior_specs.items():
decision_steps, terminal_steps = env.get_steps(behavior_name)
for steps in [decision_steps, terminal_steps]:
for agent_id in steps.agent_id:
self._agent_id_to_behavior[agent_id] = behavior_name

agent_observation_specs = [
_unity_to_torchrl_spec_transform(
spec, dtype=np.dtype("float32"), device=self.device
)
for spec in behavior_unity_spec.observation_specs
]
agent_observation_spec = torch.stack(agent_observation_specs, dim=0)
observation_specs[agent_id] = agent_observation_spec

behavior_id_specs[agent_id] = UnboundedDiscreteTensorSpec(
shape=1, device=self.device, dtype=torch.int8
)
agent_id_specs[agent_id] = UnboundedDiscreteTensorSpec(
shape=1, device=self.device, dtype=torch.int8
)

(
agent_action_spec,
agent_action_mask_spec,
) = _unity_to_torchrl_spec_transform(
behavior_unity_spec.action_spec,
dtype=np.int32,
device=self.device,
)
action_specs[agent_id] = agent_action_spec
action_mask_specs[agent_id] = agent_action_mask_spec

reward_specs[agent_id] = UnboundedContinuousTensorSpec(
shape=[1], device=self.device
)
done_specs[agent_id] = DiscreteTensorSpec(
n=2, shape=[1], dtype=torch.bool, device=self.device
)
valid_mask_specs[agent_id] = DiscreteTensorSpec(
n=2, shape=[1], dtype=torch.bool, device=self.device
)

self.observation_spec = CompositeSpec(
observation=observation_specs[0],
behavior_id=behavior_id_spec,
agent_id=agent_id_spec,
)
self.action_spec, self.action_mask_spec = _unity_to_torchrl_spec_transform(
behavior_unity_spec.action_spec, dtype=np.int32, device=self.device
)
self.reward_spec = UnboundedContinuousTensorSpec(shape=[1], device=self.device)
self.done_spec = DiscreteTensorSpec(
n=2, shape=[1], dtype=torch.bool, device=self.device
observation=torch.stack(observation_specs, dim=0)
)
self.behavior_id_spec = torch.stack(behavior_id_specs, dim=0)
self.agent_id_spec = torch.stack(agent_id_specs, dim=0)
self.action_spec = torch.stack(action_specs, dim=0)
# FIXME: Support action masks
# self.action_mask_spec = torch.stack(action_mask_specs, dim=0)
self.reward_spec = torch.stack(reward_specs, dim=0)
self.done_spec = torch.stack(done_specs, dim=0)
self.valid_mask_spec = torch.stack(valid_mask_specs, dim=0)
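
The stacked specs also shape how data is read back in: rewards, dones, and the new valid_mask are encoded as whole per-agent batches, and individual entries can be zeroed to pad agents that did not report this step. A hedged sketch of that pattern follows (values invented; it assumes encode and zero on stacked specs behave the way the read_* helpers and _get_next_tensordict below use them).

import numpy as np
import torch
from torchrl.data import DiscreteTensorSpec, UnboundedContinuousTensorSpec

num_agents = 3  # hypothetical agent count

reward_spec = torch.stack(
    [UnboundedContinuousTensorSpec(shape=[1]) for _ in range(num_agents)], dim=0
)
valid_mask_spec = torch.stack(
    [DiscreteTensorSpec(n=2, shape=[1], dtype=torch.bool) for _ in range(num_agents)],
    dim=0,
)

# Whole-batch encoding, as read_reward / read_valid_mask do.
rewards = reward_spec.encode(np.stack([[0.5], [0.0], [1.0]], axis=0))
valid = valid_mask_spec.encode(np.stack([[True], [True], [False]], axis=0))

# Per-agent zero padding, as _get_next_tensordict does for missing agents.
padded_reward = reward_spec[2].zero()  # tensor([0.])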

def __repr__(self) -> str:
return (
@@ -166,27 +196,34 @@ def _check_kwargs(self, kwargs: dict):
if not isinstance(env, BaseEnv):
raise TypeError("env is not of type 'mlagents_envs.base_env.BaseEnv'.")
if "frame_skip" in kwargs and kwargs["frame_skip"] != 1:
# This functionality is difficult to support because not all agents will request
# decisions at each timestep and different agents might request decisions at
# different timesteps. This makes it difficult to do things like keep track
# of rewards.
# This functionality is difficult to support because not all agents will
# request decisions at each timestep and different agents might request
# decisions at different timesteps. This makes it difficult to do things
# like keep track of rewards.
raise ValueError(
"Currently, frame_skip is not supported for Unity environments."
)

def behavior_id_to_name(self, behavior_id: int):
raise self._behavior_names[behavior_id]
return self._behavior_names[behavior_id]

def read_reward(self, reward):
return self.reward_spec.encode(reward, ignore_device=True)
def read_obs(self, obs):
return self.observation_spec.encode({"observation": obs})

def read_obs(self, obs: np.ndarray, behavior_name: str, agent_id: int):
behavior_id = self._behavior_names.index(behavior_name)
observations = self.observation_spec.encode(
{"observation": obs, "behavior_id": behavior_id, "agent_id": agent_id},
ignore_device=True,
def read_behavior(self, behavior_name):
behavior_id = np.array(
[self._behavior_names.index(name) for name in behavior_name]
)
return observations
return self.behavior_id_spec.encode(behavior_id)

def read_agent_id(self, agent_id):
return self.agent_id_spec.encode(agent_id)

def read_reward(self, reward):
return self.reward_spec.encode(reward)

def read_valid_mask(self, valid):
return self.valid_mask_spec.encode(valid)

def read_action(self, action):
action = self.action_spec.to_numpy(action, safe=False)
@@ -195,6 +232,7 @@ def read_action(self, action):
# used for the number of agents in the game.
if action.ndim == 0:
action = np.expand_dims(action, axis=0)

if isinstance(self.action_spec, CompositeSpec):
action = self.action_spec.to_numpy(action, safe=False)
continuous_action = np.expand_dims(action["continuous"], axis=0)
@@ -218,86 +256,72 @@ def read_action_mask(self, action_mask):
def read_done(self, done):
return self.done_spec.encode(done)

def _behavior_name_update(self):
self._live_behavior_names = list(self.behavior_specs.keys())
for k in self._live_behavior_names:
if k not in self._behavior_names:
# We only add to self._behavior_names if the
# behavior name doesn't exist. This helps us
# ensure that the index of old behaviors stays
# the same and that we don't have duplicate entries.
# This is important since we use the index of the behavior
# name as an id for that behavior.
self._behavior_names.append(k)

def _batch_update(self, behavior_name):
self._current_step_idx = 0
self._current_behavior_name = behavior_name
self._decision_steps, self._terminal_steps = self.get_steps(behavior_name)

def _get_next_tensordict(self):
num_steps = len(self._decision_steps) + len(self._terminal_steps)
if self._current_step_idx >= num_steps:
raise ValueError("All agents already have actions")
done = False if self._current_step_idx < len(self._decision_steps) else True
steps = self._decision_steps if not done else self._terminal_steps
agent_id = steps.agent_id[self._current_step_idx]
step = steps[agent_id]
# FIXME: Need Tuple support here so we can support observations of varying dimensions.
# Thus, for now we use only the first observation.
obs, reward = step.obs[0], step.reward
observations = [None] * self.num_agents
behavior_name = [None] * self.num_agents
agent_ids = [None] * self.num_agents
rewards = [None] * self.num_agents
dones = [None] * self.num_agents
valid_masks = [None] * self.num_agents

for behavior_name_ in self.behavior_specs.keys():
decision_steps, terminal_steps = self.get_steps(behavior_name_)
for i, steps in enumerate([decision_steps, terminal_steps]):
for agent_id_ in steps.agent_id:
step = steps[agent_id_]

rewards[agent_id_] = step.reward
behavior_name[agent_id_] = behavior_name_
agent_ids[agent_id_] = step.agent_id
observations[agent_id_] = np.stack(step.obs, axis=0)
dones[agent_id_] = False if i == 0 else True
valid_masks[agent_id_] = True

missing_agents = set(range(self.num_agents)) - set(agent_ids)
for missing_agent in missing_agents:
observations[missing_agent] = self.observation_spec["observation"][
missing_agent
].zero()
behavior_name[missing_agent] = self._agent_id_to_behavior[missing_agent]
agent_ids[missing_agent] = missing_agent
rewards[missing_agent] = self.reward_spec[missing_agent].zero()
dones[missing_agent] = self.done_spec[missing_agent].zero()
valid_masks[missing_agent] = False

tensordict_out = TensorDict(
source=self.read_obs(
obs, behavior_name=self._current_behavior_name, agent_id=agent_id
),
source=self.read_obs(np.stack(observations, axis=0)),
batch_size=self.batch_size,
device=self.device,
)
# tensordict_out.set("action_mask", self.read_action_mask(action_mask))
tensordict_out.set("reward", self.read_reward(reward))
tensordict_out.set("done", self.read_done(done))
tensordict_out.set("behavior_id", self.read_behavior(behavior_name))
tensordict_out.set("agent_id", self.read_agent_id(np.stack(agent_ids, axis=0)))
tensordict_out.set("reward", self.read_reward(np.stack(rewards, axis=0)))
tensordict_out.set("done", self.read_done(np.stack(dones, axis=0)))
tensordict_out.set(
"valid_mask", self.read_valid_mask(np.stack(valid_masks, axis=0))
)
return tensordict_out
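
For reference, the tensordict assembled above carries a leading agent dimension on every entry. A minimal stand-alone sketch of that layout with plain tensors instead of encoded specs (field names match the diff; shapes and values are invented):

import torch
from tensordict import TensorDict

num_agents = 3  # hypothetical agent count

step_data = TensorDict(
    source={
        "observation": torch.zeros(num_agents, 2, 8),  # invented obs shape
        "behavior_id": torch.zeros(num_agents, 1, dtype=torch.int8),
        "agent_id": torch.arange(num_agents, dtype=torch.int8).unsqueeze(-1),
        "reward": torch.zeros(num_agents, 1),
        "done": torch.zeros(num_agents, 1, dtype=torch.bool),
        # valid_mask is False for agents that were padded with zeros because
        # they did not report a decision or terminal step.
        "valid_mask": torch.tensor([[True], [True], [False]]),
    },
    batch_size=[],
)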

def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We step through each agent one at a time, and only perform
# an environment step (send the actions to the environment)
# once all agents are processed. This is because Unity requires
# all agents to have actions before stepping or else the zero
# action will be sent.
#
# In order to step through each agent, we first iterate through
# all behaviors and determine actions for each agent in that behavior,
# repeating the loop until all behaviors and their
# agents are accounted for. We then perform an environment step.
action = tensordict.get("action")
unity_action = self.read_action(action)
self.set_action_for_agent(
self._current_behavior_name, tensordict.get("agent_id").item(), unity_action
eligible_agent_mask = torch.logical_and(
tensordict["valid_mask"], torch.logical_not(tensordict["done"])
)
self._current_step_idx += 1
try:
tensordict_out = self._get_next_tensordict()
return tensordict_out.select().set("next", tensordict_out)
except ValueError:
behavior_id = self._live_behavior_names.index(self._current_behavior_name)
# If we have more behaviors to go through, continue.
# Otherwise, step the environment and then continue again.
if behavior_id < len(self._live_behavior_names) - 1:
self._current_behavior_name = self._live_behavior_names[behavior_id + 1]
self._batch_update(self._current_behavior_name)
tensordict_out = self._get_next_tensordict()
return tensordict_out.select().set("next", tensordict_out)
else:
self._env.step()
self._behavior_name_update()
self._batch_update(self._live_behavior_names[0])
tensordict_out = self._get_next_tensordict()
return tensordict_out.select().set("next", tensordict_out)
behavior_ids = tensordict["behavior_id"][eligible_agent_mask]
agent_ids = tensordict["agent_id"][eligible_agent_mask]
actions = tensordict["action"].unsqueeze(-1)[eligible_agent_mask]
for action, behavior_id, agent_id in zip(actions, behavior_ids, agent_ids):
unity_action = self.read_action(action)
self.set_action_for_agent(
self.behavior_id_to_name(behavior_id.item()),
agent_id.item(),
unity_action,
)
self._env.step()
tensordict_out = self._get_next_tensordict()
return tensordict_out.select().set("next", tensordict_out)
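
The eligibility test in _step is simply "valid and not done": only agents that reported data and are still alive get an action forwarded to Unity. A small stand-alone sketch of the masking with invented values:

import torch

valid_mask = torch.tensor([[True], [True], [False], [True]])
done = torch.tensor([[False], [True], [False], [False]])

# Same test as in _step: valid (reported a step) and not done.
eligible = torch.logical_and(valid_mask, torch.logical_not(done))

# Agent indices whose actions would be sent via set_action_for_agent.
eligible_ids = torch.nonzero(eligible.squeeze(-1)).squeeze(-1)
print(eligible_ids.tolist())  # [0, 3]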

def _reset(self, tensordict: TensorDictBase | None = None, **kwargs):
self._env.reset(**kwargs)
self._behavior_name_update()
self._batch_update(self._live_behavior_names[0])
return self._get_next_tensordict()

