diff --git a/notebooks/registry_tutorial.ipynb b/notebooks/registry_tutorial.ipynb new file mode 100644 index 00000000..cccf4e38 --- /dev/null +++ b/notebooks/registry_tutorial.ipynb @@ -0,0 +1,568 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Introduction to Hive Registry\n", + "In this tutorial, we will explain how to register classes and objects with the RLHive framework.\n", + "\n", + "The registry module `hive.utils.registry` is used to register classes (corresponding to the agent, environment, logger, or runner) in the RLHive Registry. In other words, it allows you to register different types of `Registrable` classes and objects, and generates constructors for those classes in the form of `get_{type_name}`. These constructors allow you to construct objects from dictionary configs. These configs should have two fields: \n", + "\n", + "\n", + "1. `name` - name used when registering a class in the registry\n", + "2. `**kwargs` - keyword arguments that will be passed to the constructor of the object\n", + "\n", + "\n", + "These constructors can also build objects recursively, i.e. if a config contains the config for another `Registrable` object, this will be automatically created before being passed to the constructor of the original object. These constructors also allow you to specify or override arguments for object constructors directly from the command line. These parameters are specified in dot `.` notation. 
They can also handle lists and dictionaries of `Registrable` objects.\n", + "\n", + "For example, let’s consider the following scenario: Your `agent` class has an argument `arg1` which is annotated to be `List[Class1]` (where `Class1` is `Registrable`), and the `Class1` constructor takes an argument `arg2`. In the passed YAML config, there are two different `Class1` object configs listed. The constructor will check to see if both `agent.arg1.0.arg2` and `agent.arg1.1.arg2` have been passed.\n", + "\n", + "The parameters passed on the command line will be parsed according to the type annotation of the corresponding low-level constructor. If the annotation is not one of `int`, `float`, `str`, or `bool`, the string is simply loaded into Python using a YAML loader.\n", + "\n", + "Each constructor returns the object, as well as a dictionary config with all the parameters used to create the object and any `Registrable` objects created in the process of creating this object." + ], + "metadata": { + "id": "TiiOCUnwDlXB" + } + }, + { + "cell_type": "code", + "source": [ + "%%capture\n", + "## used for updating config.yaml files\n", + "!pip install ruamel.yaml\n", + "!pip install pyglet\n", + "\n", + "!pip install git+https://github.com/chandar-lab/RLHive.git@dev" + ], + "metadata": { + "id": "rKp21vrIVavt" + }, + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "%%capture\n", + "!apt-get install x11-utils > /dev/null 2>&1 \n", + "!pip install pyglet > /dev/null 2>&1 \n", + "!apt-get install -y xvfb python-opengl > /dev/null 2>&1\n", + "!pip install gym pyvirtualdisplay > /dev/null 2>&1" + ], + "metadata": { + "id": "4r0gfzupVapd" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "## Required imports\n", + "import hive\n", + "import torch\n", + "from hive.utils.utils import Registrable" + ], + "metadata": { + "id": "M6R2hd6KnJy9" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": 
"markdown", + "source": [ + "### Registering an Environment\n", + "\n", + "Consider registering a custom environment class named `Grid` (which inherits from `BaseEnv`) in the RLHive registry. " + ], + "metadata": { + "id": "YrlRz676D_Dm" + } + }, + { + "cell_type": "code", + "source": [ + "from hive.envs.base import BaseEnv\n", + "class Grid(BaseEnv):\n", + "    def __init__(self, env_name = 'Grid', **kwargs):\n", + "        pass\n", + "    def reset(self):\n", + "        pass\n", + "    def step(self):\n", + "        pass\n", + "    def render(self):\n", + "        pass\n", + "    def close(self):\n", + "        pass\n", + "    def save(self):\n", + "        pass" + ], + "metadata": { + "id": "PbefnmFLJprJ" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "`registry.register` is used to register the class of interest. The parameters of the method are as follows.\n", + "- `name` *(str)* - Name of the class/object being registered\n", + "- `constructor` *(callable)* - Callable that will be passed all kwargs from configs and be analyzed to get type annotations\n", + "- `type` *(type)* - Type of class/object being registered. Should be a subclass of `Registrable`" + ], + "metadata": { + "id": "zr8mLhSN8gAd" + } + }, + { + "cell_type": "code", + "source": [ + "from hive.utils.registry import registry\n", + "registry.register(name = 'Grid', \n", + "                  constructor = Grid, \n", + "                  type = BaseEnv)" + ], + "metadata": { + "id": "J3DGzZgP-u28" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "More than one environment can be registered at once using the `registry.register_all` method. The parameters of the method are as follows.\n", + "\n", + "- `base_class` *(type)* - Corresponds to the `type` parameter of `registry.register`\n", + "- `class_dict` *(dict[str, callable])* - A dictionary mapping from name to constructor\n", + "\n", + "Consider registering three environments, `Gridv1`, `Gridv2`, and `Gridv3` in the RLHive registry. 
" + ], + "metadata": { + "id": "sOy_VsfPLAIp" + } + }, + { + "cell_type": "code", + "source": [ + "class Gridv1(BaseEnv):\n", + "    def __init__(self, env_name = 'Gridv1', **kwargs):\n", + "        pass\n", + "class Gridv2(BaseEnv):\n", + "    def __init__(self, env_name = 'Gridv2', **kwargs):\n", + "        pass\n", + "class Gridv3(BaseEnv):\n", + "    def __init__(self, env_name = 'Gridv3', **kwargs):\n", + "        pass\n", + "\n", + "registry.register_all(\n", + "    BaseEnv,\n", + "    {\n", + "        \"Gridv1\": Gridv1,\n", + "        \"Gridv2\": Gridv2,\n", + "        \"Gridv3\": Gridv3,\n", + "    },\n", + ")" + ], + "metadata": { + "id": "Ud2IRJzUK_wI" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Registering an Agent\n", + "\n", + "Consider registering a custom agent class named `LearningAgent` (which inherits from the `Agent` class) in the RLHive registry. " + ], + "metadata": { + "id": "LPEtPLw2MrD_" + } + }, + { + "cell_type": "code", + "source": [ + "from hive.agents.agent import Agent\n", + "\n", + "class LearningAgent(Agent):\n", + "    def __init__(self):\n", + "        pass\n", + "    def act(self):\n", + "        pass" + ], + "metadata": { + "id": "I9TrG7DFMq4E" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from hive.utils.registry import registry\n", + "registry.register(name = 'LearningAgent', \n", + "                  constructor = LearningAgent, \n", + "                  type = Agent)" + ], + "metadata": { + "id": "slMh_1bR-s1a" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "More than one agent can be registered at once using the `register_all` method. Consider registering three agents, `LearningAgentV1`, `LearningAgentV2`, and `LearningAgentV3` in the RLHive registry."
+ ], + "metadata": { + "id": "kdcN7NZlT6dD" + } + }, + { + "cell_type": "code", + "source": [ + "class LearningAgentV1(Agent):\n", + "    def __init__(self):\n", + "        pass\n", + "    def act(self):\n", + "        pass\n", + "class LearningAgentV2(Agent):\n", + "    def __init__(self):\n", + "        pass\n", + "    def act(self):\n", + "        pass\n", + "class LearningAgentV3(Agent):\n", + "    def __init__(self):\n", + "        pass\n", + "    def act(self):\n", + "        pass\n", + "\n", + "registry.register_all(\n", + "    Agent,\n", + "    {\n", + "        \"LearningAgentV1\": LearningAgentV1,\n", + "        \"LearningAgentV2\": LearningAgentV2,\n", + "        \"LearningAgentV3\": LearningAgentV3,\n", + "    },\n", + ")" + ], + "metadata": { + "id": "WUDq9CoJMYXf" + }, + "execution_count": 9, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Registering a Logger" + ], + "metadata": { + "id": "Ku9Jz5EKbmRE" + } + }, + { + "cell_type": "markdown", + "source": [ + "Consider registering a custom logger class named `CustomLogger` (which inherits from the `Logger` class) in the RLHive registry. " + ], + "metadata": { + "id": "ureIVBKcbtZJ" + } + }, + { + "cell_type": "code", + "source": [ + "from hive.utils.loggers import Logger\n", + "\n", + "class CustomLogger(Logger):\n", + "    def __init__(self):\n", + "        pass\n", + "    def update_step(self, timescale):\n", + "        pass" + ], + "metadata": { + "id": "70ys0SvobnkD" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from hive.utils.registry import registry\n", + "registry.register(name = 'CustomLogger', \n", + "                  constructor = CustomLogger, \n", + "                  type = Logger)" + ], + "metadata": { + "id": "otDqYlUWcDkn" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "More than one logger can be registered at once using the `register_all` method. Consider registering three loggers, `CustomLoggerV1`, `CustomLoggerV2`, and `CustomLoggerV3` in the RLHive registry."
+ ], + "metadata": { + "id": "A8ebt6CvdYLX" + } + }, + { + "cell_type": "code", + "source": [ + "class CustomLoggerV1(Logger):\n", + "    def __init__(self):\n", + "        pass\n", + "    def update_step(self, timescale):\n", + "        pass\n", + "class CustomLoggerV2(Logger):\n", + "    def __init__(self):\n", + "        pass\n", + "    def update_step(self, timescale):\n", + "        pass\n", + "class CustomLoggerV3(Logger):\n", + "    def __init__(self):\n", + "        pass\n", + "    def update_step(self, timescale):\n", + "        pass\n", + "\n", + "registry.register_all(\n", + "    Logger,\n", + "    {\n", + "        \"CustomLoggerV1\": CustomLoggerV1,\n", + "        \"CustomLoggerV2\": CustomLoggerV2,\n", + "        \"CustomLoggerV3\": CustomLoggerV3,\n", + "    },\n", + ")" + ], + "metadata": { + "id": "PZZbEC6QdX-o" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Registering with a Custom Data Type" + ], + "metadata": { + "id": "S5STl9CCYfOy" + } + }, + { + "cell_type": "markdown", + "source": [ + "#### Registering Initialization Function with Custom Data Type
\n", + "\n", + "In this example, a custom initialization function `variance_scaling_` is defined below.\n", + "\n", + "\n", + "" + ], + "metadata": { + "id": "XL-EVHj3cweG" + } + }, + { + "cell_type": "code", + "source": [ + "import math\n", + "def variance_scaling_(tensor, scale=1.0, mode=\"fan_in\", distribution=\"uniform\"):\n", + "    \"\"\"Implements the :py:class:`tf.keras.initializers.VarianceScaling`\n", + "    initializer in PyTorch.\n", + "    Args:\n", + "        tensor (torch.Tensor): Tensor to initialize.\n", + "        scale (float): Scaling factor (must be positive).\n", + "        mode (str): Must be one of `\"fan_in\"`, `\"fan_out\"`, and `\"fan_avg\"`.\n", + "        distribution: Random distribution to use, must be one of\n", + "            \"truncated_normal\", \"untruncated_normal\" and \"uniform\".\n", + "    Returns:\n", + "        Initialized tensor.\n", + "    \"\"\"\n", + "    fan = calculate_correct_fan(tensor, mode)\n", + "    scale /= fan\n", + "    if distribution == \"truncated_normal\":\n", + "        stddev = math.sqrt(scale) / 0.87962566103423978\n", + "        return torch.nn.init.trunc_normal_(tensor, 0.0, stddev, -2 * stddev, 2 * stddev)\n", + "    elif distribution == \"untruncated_normal\":\n", + "        stddev = math.sqrt(scale)\n", + "        return torch.nn.init.normal_(tensor, 0.0, stddev)\n", + "    elif distribution == \"uniform\":\n", + "        limit = math.sqrt(3.0 * scale)\n", + "        return torch.nn.init.uniform_(tensor, -limit, limit)\n", + "    else:\n", + "        raise ValueError(f\"Distribution {distribution} not supported\")\n", + "\n", + "def calculate_correct_fan(tensor, mode):\n", + "    \"\"\"Calculate fan of tensor.\n", + "    Args:\n", + "        tensor (torch.Tensor): Tensor to calculate fan of.\n", + "        mode (str): Which type of fan to compute. 
Must be one of `\"fan_in\"`,\n", + "            `\"fan_out\"`, and `\"fan_avg\"`.\n", + "    Returns:\n", + "        Fan of the tensor based on the mode.\n", + "    \"\"\"\n", + "    fan_in, fan_out = torch.nn.init._calculate_fan_in_and_fan_out(tensor)\n", + "    if mode == \"fan_in\":\n", + "        return fan_in\n", + "    elif mode == \"fan_out\":\n", + "        return fan_out\n", + "    elif mode == \"fan_avg\":\n", + "        return (fan_in + fan_out) / 2\n", + "    else:\n", + "        raise ValueError(f\"Fan mode {mode} not supported\")" + ], + "metadata": { + "id": "1ECMXZlJykXY" + }, + "execution_count": 13, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "The cell below demonstrates how to register `variance_scaling_` (and other standard initialization functions) with a custom data type `InitializationFn`." + ], + "metadata": { + "id": "u_ek1YFXRDqL" + } + }, + { + "cell_type": "code", + "source": [ + "class InitializationFn(Registrable):\n", + "    \"\"\"A wrapper for callables that produce initialization functions.\n", + "    These wrapped callables can be partially initialized through configuration\n", + "    files or command line arguments.\n", + "    \"\"\"\n", + "\n", + "    @classmethod\n", + "    def type_name(cls):\n", + "        \"\"\"\n", + "        Returns:\n", + "            \"init_fn\"\n", + "        \"\"\"\n", + "        return \"init_fn\"\n", + "\n", + "\n", + "registry.register_all(\n", + "    InitializationFn,\n", + "    {\n", + "        \"uniform\": torch.nn.init.uniform_,\n", + "        \"normal\": torch.nn.init.normal_,\n", + "        \"constant\": torch.nn.init.constant_,\n", + "        \"ones\": torch.nn.init.ones_,\n", + "        \"zeros\": torch.nn.init.zeros_,\n", + "        \"eye\": torch.nn.init.eye_,\n", + "        \"dirac\": torch.nn.init.dirac_,\n", + "        \"xavier_uniform\": torch.nn.init.xavier_uniform_,\n", + "        \"xavier_normal\": torch.nn.init.xavier_normal_,\n", + "        \"kaiming_uniform\": torch.nn.init.kaiming_uniform_,\n", + "        \"kaiming_normal\": torch.nn.init.kaiming_normal_,\n", + "        \"orthogonal\": torch.nn.init.orthogonal_,\n", + "        \"sparse\": torch.nn.init.sparse_,\n", + "
        \"variance_scaling\": variance_scaling_,\n", + "    },\n", + ")" + ], + "metadata": { + "id": "5IXvPxacc0VD" + }, + "execution_count": 23, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "get_init_fn = getattr(registry, f\"get_{InitializationFn.type_name()}\")" + ], + "metadata": { + "id": "hC9TJ9vRl7Gj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "#### Registering Optimizer Function with Custom Data Type\n", + "\n", + "Similar to the previous example, we can also register optimizer functions with a custom data type. \n" + ], + "metadata": { + "id": "3nlBl_bvu7ec" + } + }, + { + "cell_type": "code", + "source": [ + "class OptimizationFn(Registrable):\n", + "    \"\"\"A wrapper for callables for optimization functions.\n", + "    These wrapped callables can be partially initialized through configuration\n", + "    files or command line arguments.\n", + "    \"\"\"\n", + "\n", + "    @classmethod\n", + "    def type_name(cls):\n", + "        \"\"\"\n", + "        Returns:\n", + "            \"opt_fn\"\n", + "        \"\"\"\n", + "        return \"opt_fn\"\n", + "\n", + "\n", + "registry.register_all(\n", + "    OptimizationFn,\n", + "    {\n", + "        \"Adam\": torch.optim.Adam,\n", + "        \"SGD\": torch.optim.SGD,\n", + "        \"RMSprop\": torch.optim.RMSprop,\n", + "        \"Adagrad\": torch.optim.Adagrad\n", + "    },\n", + ")" + ], + "metadata": { + "id": "zdaPz7qpu-CK" + }, + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "get_optimizer_fn = getattr(registry, f\"get_{OptimizationFn.type_name()}\")" + ], + "metadata": { + "id": "uv2XsFMsxok_" + }, + "execution_count": null, + "outputs": [] + } + ] +}
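
Note on the `get_{type_name}` pattern: the notebook obtains `get_init_fn` and `get_optimizer_fn` but never calls them. The idea described in the introduction (register a constructor under a name, then build an object from a config with a `name` and keyword arguments, returning the object together with the config that was used) can be illustrated without RLHive installed. The following is a minimal sketch under stated assumptions, not RLHive's implementation: `ToyRegistry`, `ConstantSchedule`, and the `kwargs` config key are hypothetical names chosen for illustration, and the sketch omits recursive construction, type-annotation parsing, and command-line overrides.

```python
from typing import Any, Callable, Dict, Tuple


class ToyRegistry:
    """Toy stand-in for a config-driven registry (hypothetical, NOT RLHive's code)."""

    def __init__(self) -> None:
        # Maps type_name -> {registered name -> constructor}.
        self._constructors: Dict[str, Dict[str, Callable[..., Any]]] = {}

    def register(self, name: str, constructor: Callable[..., Any], type_name: str) -> None:
        """Store the constructor and expose a getter named get_{type_name}."""
        self._constructors.setdefault(type_name, {})[name] = constructor
        setattr(self, f"get_{type_name}", lambda config: self._build(type_name, config))

    def _build(self, type_name: str, config: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
        """Build an object from a {'name': ..., 'kwargs': {...}} config."""
        kwargs = dict(config.get("kwargs", {}))  # 'kwargs' key is an assumption
        constructor = self._constructors[type_name][config["name"]]
        obj = constructor(**kwargs)
        # Return the object plus the config that was actually used to create it.
        return obj, {"name": config["name"], "kwargs": kwargs}


class ConstantSchedule:
    """Hypothetical class used only to have something to register."""

    def __init__(self, value: float = 0.0) -> None:
        self.value = value


toy_registry = ToyRegistry()
toy_registry.register("constant", ConstantSchedule, "schedule")

# Same lookup style as the notebook: fetch the generated getter by name.
get_schedule = getattr(toy_registry, "get_schedule")
schedule, used_config = get_schedule({"name": "constant", "kwargs": {"value": 0.5}})
```

Here `schedule` is a `ConstantSchedule` built with `value=0.5`, and `used_config` records the name and keyword arguments that produced it, which mirrors the "object plus config" return described in the introduction.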