From bf5afbac1d592c53f5d1f0f1f0fb6ee20b2d6010 Mon Sep 17 00:00:00 2001
From: dpressel <dpressel@gmail.com>
Date: Mon, 25 Jul 2022 00:23:24 -0400
Subject: [PATCH 1/4] add OPT

---
 src/mint/opt.py | 356 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 356 insertions(+)
 create mode 100644 src/mint/opt.py

diff --git a/src/mint/opt.py b/src/mint/opt.py
new file mode 100644
index 0000000..4193b6c
--- /dev/null
+++ b/src/mint/opt.py
@@ -0,0 +1,356 @@
+import torch
+import torch.nn as nn
+import os
+from typing import Optional
+from mint.common import WeightTiedVocabProjection
+from mint.preln import PreLayerNormTransformerEncoder
+from mint.postln import TransformerEncoder
+import logging
+
+logger = logging.getLogger("mint")
+
+class OPTLearnedPositionalEmbedding(nn.Module):
+    """Learned positional embeddings for OPT
+
+    The embeddings are the sum of 2 inputs: word embeddings and learned positional embeddings.
+    The word embeddings use the token ids to look up a dense representation, and the positional
+    embeddings are indexed by position offset by `OPT_POS_OFFSET` (2), following the OPT convention.
+    These two embeddings are added together in the forward
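+
+    As a rough sketch of what `forward` computes (schematic only, using the submodules defined below):
+
+    .. code-block:: python
+
+        ids = torch.tensor([[5, 6, 7]])               # token ids, shape [B=1, T=3]
+        positions = torch.arange(3) + OPT_POS_OFFSET  # [2, 3, 4] -- OPT offsets positions by 2
+        y = self.word_embeddings(ids) + self.position_embeddings(positions).unsqueeze(0)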
+    """
+
+    OPT_POS_OFFSET = 2
+
+    def __init__(
+            self,
+            vocab_dim: int,
+            hidden_dim: int = 768,
+            padding_idx: int = 0,
+            max_seq_len: int = 2048,
+    ):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_dim, hidden_dim, padding_idx)
+        self.position_embeddings = nn.Embedding(
+            max_seq_len + OPTLearnedPositionalEmbedding.OPT_POS_OFFSET, hidden_dim
+        )
+
+    def forward(
+            self, x: torch.Tensor, token_type: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Takes a tensor of shape `[B, T]` and an optional `token_type` of same shape
+
+        :param x: A tensor of word one-hots, shape `[B, T]`
+        :param token_type: Ignored for OPT!
+        :return: The sum of the positional and word embeddings
+        """
+        embed = self.word_embeddings(x)
+
+        position = self.position_embeddings(
+            torch.arange(x.shape[-1], dtype=x.dtype).to(x.device)
+            + OPTLearnedPositionalEmbedding.OPT_POS_OFFSET
+        ).unsqueeze(0)
+
+        return embed + position
+
+    @property
+    def weight(self):
+        """Access word_embeddings weights
+
+        :return: The word_embeddings weights
+        """
+        return self.word_embeddings.weight
+
+
+class OPTTransformerLM(PreLayerNormTransformerEncoder):
+    """OPT LM predicts tokens from left-to-right, with pre-layer-norm encoders
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        padding_idx: int = 0,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        layer_norm_eps: float = 1e-12,
+        activation: nn.Module = nn.ReLU(),
+        feed_forward_size: Optional[int] = None,
+        max_seq_len: int = 2048,
+        **kwargs,
+    ):
+        super().__init__(
+            OPTLearnedPositionalEmbedding,
+            vocab_size,
+            padding_idx,
+            hidden_size,
+            num_heads,
+            num_layers,
+            dropout,
+            layer_norm_eps,
+            activation,
+            feed_forward_size,
+            max_seq_len,
+        )
+        self.activation = activation
+
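+        # Lower-triangular causal mask stored as [1, 1, max_seq_len, max_seq_len] so it broadcasts
+        # over the batch and head dimensions and can be sliced down to the input length in forward()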
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    (
+                        max_seq_len,
+                        max_seq_len,
+                    ),
+                    dtype=torch.uint8,
+                )
+            )
+            .unsqueeze(0)
+            .unsqueeze(0),
+        )
+
+        self.output_layer = WeightTiedVocabProjection(self.embeddings.word_embeddings)
+        self.apply(self.init_layer_weights)
+
+    def create_loss(self):
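+        # ignore_index=0 matches the default padding_idx, so padded positions do not contribute to the loss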
+        return nn.CrossEntropyLoss(ignore_index=0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        token_type: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Apply the encoder from the parent, followed by penultimate and output projection
+
+        :param x: A one-hot (long) tensor of shape `[B, T]`
+        :param mask: An optional mask to take in for attention
+        :param token_type: An optional tensor of 0 or 1, shape `[B, T]`
+        :return:
+        """
+        input_mask = self.causal_mask[:, :, : x.shape[1], : x.shape[1]]
+        if mask is not None:
+            input_mask = mask & input_mask.to(dtype=torch.bool)
+
+        y = super().forward(x, input_mask)
+        y = self.output_layer(y)
+        return y
+
+
+class OPTTransformerPooledEncoder(PreLayerNormTransformerEncoder):
+    """Use our Transformer encoder with a pooling head.
+
+    We will use this model for classification
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        padding_idx: int = 0,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        layer_norm_eps: float = 1e-12,
+        activation: nn.Module = nn.ReLU(),
+        feed_forward_size: Optional[int] = None,
+        output: Optional[nn.Module] = None,
+        max_seq_len: int = 2048,
+        pool_id: Optional[int] = None,
+        **kwargs,
+    ):
+        """Set up initialization for a (post-layer-norm) Transformer with pooling output.  Defaults to bert-base settings
+
+        :param vocab_size: The size of the input vocabulary
+        :param padding_idx: The padding index, defaults to 0
+        :param hidden_size: The number of hidden units
+        :param num_heads: The number of heads for multi-headed attn.  Should divide evenly into hidden_size
+        :param num_layers: The number of transformer layers (MHA+FFN) in the architecture
+        :param dropout: The value to apply for dropout
+        :param layer_norm_eps: The epsilon used by layer norm for numerical stability
+        :param activation: The activation function to use throughout
+        :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size
+        :param output: An optional projection layer to apply at the end
+        :param max_seq_len: The maximum sequence length, for OPT this should be 2048
+        :param pool_id: An optional integer value to use for the pooling token.  If not set, we use mean pooling
+        """
+        super().__init__(
+            OPTLearnedPositionalEmbedding,
+            vocab_size,
+            padding_idx,
+            hidden_size,
+            num_heads,
+            num_layers,
+            dropout,
+            layer_norm_eps,
+            activation,
+            feed_forward_size,
+            max_seq_len,
+        )
+
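+        # If a pool_id is given, pool by selecting the embedding at that token's position;
+        # otherwise fall back to mean pooling over the non-padding timesteps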
+        self.pooling = self.mean_pool if pool_id is None else self.pool_by_id
+        self.pool_id = pool_id
+
+        self.output = output if output else nn.Identity()
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    (
+                        max_seq_len,
+                        max_seq_len,
+                    ),
+                    dtype=torch.uint8,
+                )
+            )
+            .unsqueeze(0)
+            .unsqueeze(0),
+        )
+
+        self.apply(self.init_layer_weights)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        token_type: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+
+        :param x: A one-hot (long) tensor of shape `[B, T]`
+        :param mask: An optional mask to take in for attention
+        :param token_type:
+        :return:
+        """
+        input_mask = self.causal_mask[:, :, : x.shape[1], : x.shape[1]]
+        if mask is not None:
+            input_mask = mask & input_mask.to(dtype=torch.bool)
+
+        y = self.embeddings(x, token_type)
+        for t in self.encoder:
+            y = t(y, input_mask)
+
+        y = self.pooling(x, y)
+        return self.output(y)
+
+    def pool_by_id(self, inputs, embeddings):
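+        # Select the hidden state(s) at the position of the pooling token in each sequence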
+        return embeddings[inputs == self.pool_id]
+
+    def mean_pool(self, inputs, embeddings):
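+        # Zero out padded positions and average over the true (unpadded) sequence length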
+        mask = inputs != self.padding_idx
+        seq_lengths = mask.sum(1).float()
+        embeddings = embeddings.masked_fill(~mask.unsqueeze(-1), 0.0)
+        return embeddings.sum(1) / seq_lengths.unsqueeze(-1)
+
+
+class OPTCreator:
+    @classmethod
+    def convert_state_dict(cls, tlm, hf_state_dict):
+        """Convert the state dict to TFS compatible names
+
+        The encoder token embeddings (AKA word_embeddings) are shared with the decoder token embeddings, and
+        in the HF implementation, this is done via `self.shared` so all 3 items are in the original checkpoint,
+        and we only need one of them.  We have tied these together by assignment already, so loading the encoder's
+        word embeddings updates the decoder word embeddings too
+
+        Note that the positional embeddings are different for encoder and decoder, so these are not shared and both
+        are loaded
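+
+        For example, a decoder-layer field remaps as follows (derived from the rules below):
+        `model.decoder.layers.0.self_attn.k_proj.weight` -> `encoder.0.self_attention.key.weight`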
+
+        :param tlm: The model to populate (an `OPTTransformerLM` or `OPTTransformerPooledEncoder`)
+        :param hf_state_dict: The HF checkpoint state dict
+        :return: A tuple of (unset params, unused checkpoint fields), useful for diagnostics
+        """
+        tlm_field_names = set(k for k in tlm.state_dict().keys())
+        hf_field_names = hf_state_dict.keys()
+
+        """
+        Unset params: {'layer_norm.weight', 'embeddings.word_embeddings.weight', 'layer_norm.bias'}
+        Unused checkpoint fields: {'model.decoder.embed_tokens.weight', 'model.decoder.final_layer_norm.bias', 'model.decoder.final_layer_norm.weight'}
+
+        """
+
+        unused_checkpoint_fields = set(hf_field_names)
+        remap = {}
+        for field_name in hf_field_names:
+
+            new_field_name = field_name.replace(
+                "model.decoder.embed_tokens", "embeddings.word_embeddings"
+            )
+            new_field_name = new_field_name.replace(
+                "model.decoder.embed_positions", "embeddings.position_embeddings"
+            )
+
+            new_field_name = new_field_name.replace(
+                "model.decoder.final_layer_norm", "layer_norm"
+            )
+
+            new_field_name = new_field_name.replace('model.decoder.', 'encoder.')
+
+            new_field_name = new_field_name.replace("self_attn", "self_attention")
+            new_field_name = new_field_name.replace("k_proj", "key")
+            new_field_name = new_field_name.replace("q_proj", "query")
+            new_field_name = new_field_name.replace("v_proj", "value")
+            new_field_name = new_field_name.replace("out_proj", "output")
+            new_field_name = new_field_name.replace(".layers", "")
+            new_field_name = new_field_name.replace(
+                "attention.output.dense", "self_attention.output"
+            )
+            new_field_name = new_field_name.replace("fc1", "ffn.0")
+            new_field_name = new_field_name.replace("fc2", "ffn.2")
+            new_field_name = new_field_name.replace(
+                "final_layer_norm", "output_layer_norm"
+            )
+            if new_field_name in tlm_field_names:
+                tlm_field_names.remove(new_field_name)
+                unused_checkpoint_fields.remove(field_name)
+                remap[new_field_name] = hf_state_dict[field_name]
+
+        tlm.load_state_dict(remap, strict=False)
+        return tlm_field_names, unused_checkpoint_fields
+
+    @classmethod
+    def get_vocab_and_hidden_dims(cls, hf_dict: dict) -> tuple:
+        try:
+            embeddings_weight = hf_dict[
+                [k for k in hf_dict if "decoder.embed_tokens.weight" in k][0]
+            ]
+        # Fall back to this implementation's own naming if the checkpoint isn't in HF format
+        except IndexError:
+            embeddings_weight = hf_dict[
+                [
+                    k
+                    for k in hf_dict
+                    if "embeddings.word_embeddings.weight" in k
+                ][0]
+            ]
+        return embeddings_weight.shape
+
+    @classmethod
+    def lm_from_pretrained(
+        cls, checkpoint_file_or_dir: str, map_location=None, **kwargs
+    ):
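+        """Create an OPT LM and load weights from an HF-style checkpoint file or directory"""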
+        if os.path.isdir(checkpoint_file_or_dir):
+            checkpoint = os.path.join(checkpoint_file_or_dir, "pytorch_model.bin")
+        else:
+            checkpoint = checkpoint_file_or_dir
+        hf_dict = torch.load(checkpoint, map_location=map_location)
+        vocab_size, hidden_size = OPTCreator.get_vocab_and_hidden_dims(hf_dict)
+        tlm = OPTTransformerLM(vocab_size, **kwargs)
+        missing, unused = OPTCreator.convert_state_dict(tlm, hf_dict)
+        logging.info(f"Unset params: {missing}")
+        logging.info(f"Unused checkpoint fields: {unused}")
+        return tlm
+
+    @classmethod
+    def pooled_enc_from_pretrained(
+        cls, checkpoint_file_or_dir: str, map_location=None, pool_id=None, **kwargs
+    ):
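+        """Create a pooled OPT encoder and load weights from an HF-style checkpoint file or directory"""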
+        if os.path.isdir(checkpoint_file_or_dir):
+            checkpoint = os.path.join(checkpoint_file_or_dir, "pytorch_model.bin")
+        else:
+            checkpoint = checkpoint_file_or_dir
+        hf_dict = torch.load(checkpoint, map_location=map_location)
+        vocab_size, hidden_size = OPTCreator.get_vocab_and_hidden_dims(hf_dict)
+        enc = OPTTransformerPooledEncoder(vocab_size, pool_id=pool_id, **kwargs)
+        missing, unused = OPTCreator.convert_state_dict(enc, hf_dict)
+        logging.info(f"Unset params: {missing}")
+        logging.info(f"Unused checkpoint fields: {unused}")
+        return enc

From 04299ee0dc375e64d6a1f48912c709f1c4441f30 Mon Sep 17 00:00:00 2001
From: dpressel <dpressel@gmail.com>
Date: Mon, 25 Jul 2022 00:25:35 -0400
Subject: [PATCH 2/4] Initial add

Still need to switch to tokenizers lib for tok-ing
---
 src/mint/examples/opt_completer.py | 92 ++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 src/mint/examples/opt_completer.py

diff --git a/src/mint/examples/opt_completer.py b/src/mint/examples/opt_completer.py
new file mode 100644
index 0000000..77a03af
--- /dev/null
+++ b/src/mint/examples/opt_completer.py
@@ -0,0 +1,92 @@
+import logging
+import argparse
+import os
+import torch
+from prompt_toolkit import prompt
+from prompt_toolkit.history import FileHistory
+from mint.opt import OPTCreator
+from transformers import GPT2Tokenizer
+
+logger = logging.getLogger(__file__)
+
+"""An example program where you can provide your GPT model with a priming sequence and have it complete
+"""
+
+
+def main():
+    parser = argparse.ArgumentParser(description="An interactive shell with OPT")
+    parser.add_argument("--model", type=str, required=True, help="Start from a model")
+
+    parser.add_argument(
+        "--query",
+        type=str,
+        help="Optional query.  If you pass this we wont use the repl",
+    )
+    parser.add_argument("--history_file", type=str, default=".gpt_history")
+    parser.add_argument("--max_len", type=int, default=50)
+    parser.add_argument("--sample", action="store_true")
+    parser.add_argument("--temperature", default=1.0, type=float)
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device (cuda or cpu)",
+    )
+    args = parser.parse_args()
+    logging.basicConfig(level=logging.INFO)
+
+    tokenizer = GPT2Tokenizer.from_pretrained(args.model)
+
+    model = OPTCreator.lm_from_pretrained(args.model).eval()
+    model.to(args.device)
+
+    def complete(query, sampling, temperature):
+        logger.info("Query: %s", query)
+        inputs = tokenizer.encode(query)
+        print(inputs)
+        print(tokenizer.convert_ids_to_tokens(inputs))
+        outputs = []
+        with torch.no_grad():
+
+            for i in range(args.max_len):
+
+                ids = torch.tensor(inputs, device=args.device)
+                response = model(ids.unsqueeze(0)).squeeze(0)
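+                # The LM returns logits for every position; use the logits at the last input
+                # token to predict the next one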
+                response = response[len(inputs) - 1]
+                if sampling:
+                    sample_dist = torch.softmax(response / temperature, -1)
+                    output = torch.multinomial(sample_dist, num_samples=1)
+                    response = output.squeeze().item()
+                else:
+                    response = response.argmax(-1).item()
+
+                inputs.append(response)
+                outputs.append(response)
+            #outputs = ' '.join(tokenizer.convert_ids_to_tokens(outputs))
+            outputs = tokenizer.decode(outputs)
+            return outputs
+
+    if args.query:
+        print(complete(args.query, args.sample, args.temperature))
+        return
+
+    prompt_name = f"OPT{args.version}>> "
+    history = FileHistory(args.history_file)
+    while True:
+        query = prompt(prompt_name, history=history)
+        query = query.strip()
+        if query == ":quit" or query == "quit":
+            break
+        if query == ":sample":
+            args.sample = True
+            print("Turn sampling mode on")
+            continue
+        if query == ":max":
+            args.sample = False
+            print("Turn sampling mode off")
+            continue
+        print(complete(query, args.sample, args.temperature))
+
+
+if __name__ == "__main__":
+    main()

From b6bb3358f0347a295f9cfba8bb1291b709894193 Mon Sep 17 00:00:00 2001
From: dpressel <dpressel@gmail.com>
Date: Tue, 26 Jul 2022 16:09:26 -0400
Subject: [PATCH 3/4] Use HF tokenizers directly

---
 README.md                          |  1 +
 src/mint/examples/opt_completer.py | 38 +++++++++++++++++++++++-------
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 41e709a..2f7b47b 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ Minimal PyTorch implementation of common Transformer architectures.  Currently i
 - Decoder Only
   - [GPT](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
   - [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)
+  - [OPT](https://arxiv.org/pdf/2205.01068.pdf)
 - Encoder-Decoder
   - [BART](https://arxiv.org/pdf/1910.13461v1.pdf)
   - [T5](https://arxiv.org/pdf/1910.10683.pdf)
diff --git a/src/mint/examples/opt_completer.py b/src/mint/examples/opt_completer.py
index 77a03af..73c9430 100644
--- a/src/mint/examples/opt_completer.py
+++ b/src/mint/examples/opt_completer.py
@@ -5,18 +5,37 @@
 from prompt_toolkit import prompt
 from prompt_toolkit.history import FileHistory
 from mint.opt import OPTCreator
-from transformers import GPT2Tokenizer
+from tokenizers import Tokenizer
 
 logger = logging.getLogger(__file__)
 
-"""An example program where you can provide your GPT model with a priming sequence and have it complete
+"""An example program where you can provide your OPT model with a priming sequence and have it complete
+
+The HF Tokenizers compatible tokenizer.json is available from:
+
+https://www.dropbox.com/s/ut8qj4nynhkq4cd/tokenizer.json?dl=1
+
+It was built using GPT2's tokenizer.json as a template: the "merges" field was replaced with the contents of
+"merges.txt", the "vocab" field was replaced with the contents of "vocab.json", and finally the
+post-processor was set as follows:
+
+.. code-block:: python
+
+    tokenizer.post_processor = TemplateProcessing(
+            single="</s> $A",
+            special_tokens=[
+                ("</s>", 1),
+            ],
+        )
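+
+A minimal sketch of that post-processing step (illustrative only, assuming the template tokenizer.json
+is in the working directory; not necessarily the exact script that produced the file above):
+
+.. code-block:: python
+
+    from tokenizers import Tokenizer
+    from tokenizers.processors import TemplateProcessing
+
+    tokenizer = Tokenizer.from_file("tokenizer.json")
+    # Prepend </s> to every encoded sequence; the id here must match the </s> entry in the vocab
+    tokenizer.post_processor = TemplateProcessing(
+        single="</s> $A",
+        special_tokens=[("</s>", 1)],
+    )
+    tokenizer.save("tokenizer.json")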
+
 """
 
 
 def main():
     parser = argparse.ArgumentParser(description="An interactive shell with OPT")
     parser.add_argument("--model", type=str, required=True, help="Start from a model")
-
+    parser.add_argument(
+        "--tok_file", type=str, required=True, help="Path to tokenizer.json file"
+    )
     parser.add_argument(
         "--query",
         type=str,
@@ -32,19 +51,20 @@ def main():
         default="cuda" if torch.cuda.is_available() else "cpu",
         help="Device (cuda or cpu)",
     )
+
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
-
-    tokenizer = GPT2Tokenizer.from_pretrained(args.model)
-
+    if os.path.isdir(args.tok_file):
+        args.tok_file = os.path.join(args.tok_file, "tokenizer.json")
+    tokenizer = Tokenizer.from_file(args.tok_file)
     model = OPTCreator.lm_from_pretrained(args.model).eval()
     model.to(args.device)
 
     def complete(query, sampling, temperature):
         logger.info("Query: %s", query)
-        inputs = tokenizer.encode(query)
-        print(inputs)
-        print(tokenizer.convert_ids_to_tokens(inputs))
+        tokenized_input = tokenizer.encode(query)
+        logger.info("Priming Sequence: %s", " ".join(tokenized_input.tokens))
+        inputs = tokenized_input.ids
         outputs = []
         with torch.no_grad():
 

From 846359af1d6c31aec23a61c08c7edafbcee307b9 Mon Sep 17 00:00:00 2001
From: dpressel <dpressel@gmail.com>
Date: Tue, 26 Jul 2022 16:10:09 -0400
Subject: [PATCH 4/4] update README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 2f7b47b..23e6eb2 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ Minimal PyTorch implementation of common Transformer architectures.  Currently i
 - Decoder Only
   - [GPT](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
   - [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)
-  - [OPT](https://arxiv.org/pdf/2205.01068.pdf)
+  - [OPT](https://arxiv.org/pdf/2205.01068.pdf)
 - Encoder-Decoder
   - [BART](https://arxiv.org/pdf/1910.13461v1.pdf)
   - [T5](https://arxiv.org/pdf/1910.10683.pdf)