diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 7d6efa6f..a51afc9e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -56,8 +56,8 @@ sure you look at the examples hive_single_agent_loop -c hive_multi_agent_loop -c -Finally, if instead you want to use your own custom custom components you can -simply register it with RLHive and run your config normally: +Finally, if instead you want to use your own custom components you can +simply register them with RLHive and run your config in the following way: .. code-block:: python diff --git a/docs/tutorials/agent_tutorial.rst b/docs/tutorials/agent_tutorial.rst index 624b46d2..5d394e16 100644 --- a/docs/tutorials/agent_tutorial.rst +++ b/docs/tutorials/agent_tutorial.rst @@ -32,6 +32,7 @@ First, we define the constructor: self._q_values = np.zeros(obs_dim, act_dim) self._gamma = gamma self._alpha = alpha + self._act_dim = act_dim self._epsilon_schedule = LinearSchedule(1.0, final_epsilon, explore_steps) In this constructor, we created a numpy array to keep track of the Q-values for every diff --git a/docs/tutorials/configuration_tutorial.rst b/docs/tutorials/configuration_tutorial.rst index e84a74b9..22256a13 100644 --- a/docs/tutorials/configuration_tutorial.rst +++ b/docs/tutorials/configuration_tutorial.rst @@ -36,8 +36,8 @@ In this example, :py:class:`~hive.agents.dqn_agent.DQNAgent` , :py:class:`~hive.agents.qnets.mlp.MLPNetwork` , and :py:class:`~hive.replays.circular_replay.CircularReplayBuffer` are all classes registered with RLHive. Thus, we can do this configuration directly. When the -``registry`` getter function for agents, -:py:meth:`~hive.utils.registry.Registry.get_agent` is then called with this config +``registry`` getter function for agents, +:py:meth:`~hive.utils.registry.Registry.get_agent`, is then called with this config dictionary (with the missing required arguments such as ``obs_dim`` and ``act_dim``, filled in), it will build all the inner RLHive objects automatically.
This works by using the type annotations on the constructors of the objects, so @@ -51,8 +51,8 @@ Overriding from command lines -------------------------------- When using the ``registry`` getter functions, RLHive automatically checks any command line arguments passed to see if they match/override any default or yaml configured -arguments. With ``getter`` functionyou provide a config and a prefix. That prefix -is added prepended to any argument names when searching the command line. For example, +arguments. With the ``getter`` function, you provide a config and a prefix. That prefix +is prepended to any argument names when searching the command line. For example, with the above config, if it were loaded and the :py:meth:`~hive.utils.registry.Registry.get_agent` method was called as follows: @@ -65,7 +65,7 @@ python script: ``--ag.discount_rate .95``. This can go arbitrarily deep into reg RLHive class. For example, if you wanted to change the capacity of the replay buffer, you could pass ``--ag.replay_buffer.capacity 100000``. -If the type annotation the argument ``arg`` is ``List[C]`` where C is a registered +If the type annotation of the argument ``arg`` is ``List[C]`` where C is a registered RLHive class, then you can override the argument of an individual object, ``foo``, configured through YAML by passing ``--arg.0.foo ``. diff --git a/docs/tutorials/env_tutorial.rst b/docs/tutorials/env_tutorial.rst index af6544f9..637520a0 100644 --- a/docs/tutorials/env_tutorial.rst +++ b/docs/tutorials/env_tutorial.rst @@ -22,7 +22,7 @@ Creating an Environment RLHive Environments ^^^^^^^^^^^^^^^^^^^ -Every environment used in RLHive should be a subclass of `~hive.envs.base.BaseEnv`. +Every environment used in RLHive should be a subclass of :py:class:`~hive.envs.base.BaseEnv`.
It should provide a ``reset`` function that resets the environment to a new episode and returns a tuple of ``(observation, turn)`` and a ``step`` function that takes in an action, performs the step in the environment, and returns a tuple of diff --git a/docs/tutorials/runner_tutorial.rst b/docs/tutorials/runner_tutorial.rst index 793ca331..1d1513b2 100644 --- a/docs/tutorials/runner_tutorial.rst +++ b/docs/tutorials/runner_tutorial.rst @@ -6,7 +6,7 @@ We provide two different :py:class:`~hive.runners.base.Runner` classes: for both Runner classes can be viewed in their respective files with the :py:meth:`set_up_experiment` functions. The :py:meth:`~hive.utils.registry.get_parsed_args` function can be used -to get any arguments from the command line are not part of the signatures +to get any arguments from the command line that are not part of the signatures of already registered RLHive class constructors. diff --git a/hive/agents/ddpg.py b/hive/agents/ddpg.py index 3cea236f..d7d361e9 100644 --- a/hive/agents/ddpg.py +++ b/hive/agents/ddpg.py @@ -50,18 +50,18 @@ def __init__( None, defaults to :py:class:`~torch.nn.Identity`. actor_net (FunctionApproximator): The network that takes the encoded observations from representation_net and outputs the representations - used to compute the actions (ie everything except the last layer). + used to compute the actions (i.e. everything except the last layer). critic_net (FunctionApproximator): The network that takes two inputs: the encoded observations from representation_net and actions. It outputs - the representations used to compute the values of the actions (ie + the representations used to compute the values of the actions (i.e. everything except the last layer). init_fn (InitializationFn): Initializes the weights of agent networks using create_init_weights_fn. actor_optimizer_fn (OptimizerFn): A function that takes in the list of - parameters of the actor returns the optimizer for the actor. 
If None, + parameters of the actor and returns the optimizer for the actor. If None, defaults to :py:class:`~torch.optim.Adam`. critic_optimizer_fn (OptimizerFn): A function that takes in the list of - parameters of the critic returns the optimizer for the critic. If None, + parameters of the critic and returns the optimizer for the critic. If None, defaults to :py:class:`~torch.optim.Adam`. critic_loss_fn (LossFn): The loss function used to optimize the critic. If None, defaults to :py:class:`~torch.nn.MSELoss`. diff --git a/hive/agents/rainbow.py b/hive/agents/rainbow.py index a386e19f..61353f4e 100644 --- a/hive/agents/rainbow.py +++ b/hive/agents/rainbow.py @@ -257,6 +257,7 @@ def act(self, observation): def update(self, update_info): """ Updates the DQN agent. + Args: update_info: dictionary containing all the necessary information to update the agent. Should contain a full transition, with keys for diff --git a/hive/agents/td3.py b/hive/agents/td3.py index 941d53f0..4f7c948d 100644 --- a/hive/agents/td3.py +++ b/hive/agents/td3.py @@ -62,10 +62,10 @@ def __init__( None, defaults to :py:class:`~torch.nn.Identity`. actor_net (FunctionApproximator): The network that takes the encoded observations from representation_net and outputs the representations - used to compute the actions (ie everything except the last layer). + used to compute the actions (i.e. everything except the last layer). critic_net (FunctionApproximator): The network that takes two inputs: the encoded observations from representation_net and actions. It outputs - the representations used to compute the values of the actions (ie + the representations used to compute the values of the actions (i.e. everything except the last layer). init_fn (InitializationFn): Initializes the weights of agent networks using create_init_weights_fn. 
diff --git a/hive/replays/circular_replay.py b/hive/replays/circular_replay.py index 4c2a7e01..a090c2e0 100644 --- a/hive/replays/circular_replay.py +++ b/hive/replays/circular_replay.py @@ -43,9 +43,6 @@ def __init__( a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable. action_shape: Shape of actions that will be stored in the buffer. - action_dtype: Type of actions that will be stored in the buffer. Format is - described in the description of observation_dtype. - action_shape: Shape of actions that will be stored in the buffer. action_dtype: Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype. reward_shape: Shape of rewards that will be stored in the buffer. diff --git a/hive/replays/prioritized_replay.py b/hive/replays/prioritized_replay.py index 5b004d0a..b5962a5c 100644 --- a/hive/replays/prioritized_replay.py +++ b/hive/replays/prioritized_replay.py @@ -46,9 +46,6 @@ def __init__( a numpy type, a string of the form np.uint8 or numpy.uint8 is acceptable. action_shape: Shape of actions that will be stored in the buffer. - action_dtype: Type of actions that will be stored in the buffer. Format is - described in the description of observation_dtype. - action_shape: Shape of actions that will be stored in the buffer. action_dtype: Type of actions that will be stored in the buffer. Format is described in the description of observation_dtype. reward_shape: Shape of rewards that will be stored in the buffer. diff --git a/hive/runners/utils.py b/hive/runners/utils.py index f9a089c1..5e101687 100644 --- a/hive/runners/utils.py +++ b/hive/runners/utils.py @@ -17,7 +17,7 @@ def load_config( logger_config=None, ): """Used to load config for experiments. Agents, environment, and loggers components - in main config file can be overrided based on other log files. + in the main config file can be overridden based on other config files. Args: config (str): Path to configuration file.
Either this or :obj:`preset_config` diff --git a/hive/utils/experiment.py b/hive/utils/experiment.py index cb093f01..65c57b63 100644 --- a/hive/utils/experiment.py +++ b/hive/utils/experiment.py @@ -72,6 +72,7 @@ def should_save(self): def save(self, tag="current"): """Saves the experiment. + Args: tag (str): Tag to prefix the folder. """ diff --git a/hive/utils/registry.py b/hive/utils/registry.py index 0b99cb3a..ab5f04c2 100644 --- a/hive/utils/registry.py +++ b/hive/utils/registry.py @@ -40,7 +40,7 @@ class Registry: For example, let's consider the following scenario: Your agent class has an argument `arg1` which is annotated to be `List[Class1]`, `Class1` is `Registrable`, and the `Class1` constructor takes an argument `arg2`. - In the passed yml config, there are two different Class1 object configs listed. + In the passed yml config, there are two different Class1 object configs listed, so the constructor will check to see if both `--agent.arg1.0.arg2` and `--agent.arg1.1.arg2` have been passed.
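Review note: the prefix-based command-line override that the configuration tutorial and the `Registry` docstring above describe (e.g. ``--ag.replay_buffer.capacity 100000``, ``--agent.arg1.0.arg2``) can be sketched roughly as below. This is an illustrative sketch only, not RLHive's actual implementation; the helper name ``apply_cli_overrides`` and the nested dict/list config shape are assumptions.

```python
import ast


def apply_cli_overrides(config, argv, prefix):
    """Apply overrides of the form --<prefix>.<dotted.path> <value> to a
    nested dict/list config. Numeric path components index into lists,
    mirroring the --arg.0.foo style described in the tutorial."""
    marker = "--" + prefix + "."
    i = 0
    while i + 1 < len(argv):
        if argv[i].startswith(marker):
            path = argv[i][len(marker):].split(".")
            node = config
            for key in path[:-1]:
                # Walk the dotted path; digits index lists, names index dicts.
                node = node[int(key)] if key.isdigit() else node[key]
            try:
                value = ast.literal_eval(argv[i + 1])  # numbers, lists, ...
            except (ValueError, SyntaxError):
                value = argv[i + 1]  # fall back to the raw string
            last = path[-1]
            if last.isdigit():
                node[int(last)] = value
            else:
                node[last] = value
            i += 2
        else:
            i += 1
    return config
```

With a config like ``{"discount_rate": 0.99, "replay_buffer": {"capacity": 10000}}`` and argv ``["--ag.discount_rate", ".95", "--ag.replay_buffer.capacity", "100000"]`` under prefix ``"ag"``, both leaf values are replaced, which matches the behavior the tutorial attributes to the ``registry`` getter functions.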