diff --git a/README.md b/README.md
index 41e709a..23e6eb2 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ Minimal PyTorch implementation of common Transformer architectures. Currently i
 - Decoder Only
   - [GPT](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
   - [GPT2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)
+  - [OPT](https://arxiv.org/pdf/2205.01068.pdf)
 - Encoder-Decoder
   - [BART](https://arxiv.org/pdf/1910.13461v1.pdf)
   - [T5](https://arxiv.org/pdf/1910.10683.pdf)
diff --git a/src/mint/examples/opt_completer.py b/src/mint/examples/opt_completer.py
new file mode 100644
index 0000000..73c9430
--- /dev/null
+++ b/src/mint/examples/opt_completer.py
@@ -0,0 +1,124 @@
+"""An example program where you can provide your OPT model with a priming sequence and have it complete
+
+The HF Tokenizers compatible tokenizer.json is available from:
+
+https://www.dropbox.com/s/ut8qj4nynhkq4cd/tokenizer.json?dl=1
+
+It was processed using GPT2's tokenizer.json as a template, and replacing the "merges" field with the contents of
+"merges.txt" and replacing the "vocab" field with the contents of "vocab.json", and finally, by setting the
+postprocessor as follows:
+
+.. code-block:: python
+    tokenizer.post_processor = TemplateProcessing(
+        single=" $A",
+        special_tokens=[
+            ("", 1),
+        ],
+    )
+
+"""
+
+import logging
+import argparse
+import os
+import torch
+from prompt_toolkit import prompt
+from prompt_toolkit.history import FileHistory
+from mint.opt import OPTCreator
+from tokenizers import Tokenizer
+
+logger = logging.getLogger(__name__)
+
+
+def main():
+    """Run the OPT completer on a single `--query`, or as an interactive shell
+
+    In the shell, `:sample` turns sampling on, `:max` turns it back off (argmax decoding),
+    and `:quit` or `quit` exits
+    """
+    parser = argparse.ArgumentParser(description="An interactive shell with OPT")
+    parser.add_argument("--model", type=str, required=True, help="Start from a model")
+    parser.add_argument(
+        "--tok_file", type=str, required=True, help="Path to tokenizer.json file"
+    )
+    parser.add_argument(
+        "--query",
+        type=str,
+        help="Optional query. If you pass this we wont use the repl",
+    )
+    parser.add_argument("--history_file", type=str, default=".gpt_history")
+    parser.add_argument("--max_len", type=int, default=50)
+    parser.add_argument("--sample", action="store_true")
+    parser.add_argument("--temperature", default=1.0, type=float)
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda" if torch.cuda.is_available() else "cpu",
+        help="Device (cuda or cpu)",
+    )
+
+    args = parser.parse_args()
+    logging.basicConfig(level=logging.INFO)
+    # --tok_file may also be a directory that contains tokenizer.json
+    if os.path.isdir(args.tok_file):
+        args.tok_file = os.path.join(args.tok_file, "tokenizer.json")
+    tokenizer = Tokenizer.from_file(args.tok_file)
+    model = OPTCreator.lm_from_pretrained(args.model).eval()
+    model.to(args.device)
+
+    def complete(query, sampling, temperature):
+        """Generate `args.max_len` tokens that continue `query`
+
+        :param query: The priming text
+        :param sampling: If true, sample from the softmax distribution, otherwise take the argmax
+        :param temperature: Divides the model outputs before the softmax, only used when sampling
+        :return: The decoded generated tokens (the priming text is not included)
+        """
+        logger.info("Query: %s", query)
+        tokenized_input = tokenizer.encode(query)
+        logger.info("Priming Sequence: %s", " ".join(tokenized_input.tokens))
+        inputs = tokenized_input.ids
+        outputs = []
+        with torch.no_grad():
+            for _ in range(args.max_len):
+                # The whole sequence so far is fed through the model on every step
+                ids = torch.tensor(inputs, device=args.device)
+                response = model(ids.unsqueeze(0)).squeeze(0)
+                # Only the prediction at the last position picks the next token
+                response = response[len(inputs) - 1]
+                if sampling:
+                    sample_dist = torch.softmax(response / temperature, -1)
+                    output = torch.multinomial(sample_dist, num_samples=1)
+                    response = output.squeeze().item()
+                else:
+                    response = response.argmax(-1).item()
+
+                inputs.append(response)
+                outputs.append(response)
+        return tokenizer.decode(outputs)
+
+    if args.query:
+        print(complete(args.query, args.sample, args.temperature))
+        return
+
+    # The parser defines no --version argument, so the prompt must not read args.version
+    prompt_name = "OPT>> "
+    history = FileHistory(args.history_file)
+    while True:
+        query = prompt(prompt_name, history=history)
+        query = query.strip()
+        if query in (":quit", "quit"):
+            break
+        if query == ":sample":
+            args.sample = True
+            print("Turn sampling mode on")
+            continue
+        if query == ":max":
+            args.sample = False
+            print("Turn sampling mode off")
+            continue
+        print(complete(query, args.sample, args.temperature))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/mint/opt.py b/src/mint/opt.py
new file mode 100644
index 0000000..4193b6c
--- /dev/null
+++ b/src/mint/opt.py
@@ -0,0 +1,419 @@
+import torch
+import torch.nn as nn
+import os
+from typing import Optional
+from mint.common import WeightTiedVocabProjection
+from mint.preln import PreLayerNormTransformerEncoder
+from mint.postln import TransformerEncoder
+import logging
+
+logger = logging.getLogger("mint")
+
+
+class OPTLearnedPositionalEmbedding(nn.Module):
+    """Learned positional embeddings for OPT
+
+    The embeddings are a combination of 2 inputs, word embeddings and positional embeddings
+    The word embeddings is a learned vector that uses the word one-hots to convert to a dense representation.
+    Each of these embeddings are added together in the forward
+
+    Position ids are shifted by `OPT_POS_OFFSET` before the lookup, so the positional table has
+    `max_seq_len + OPT_POS_OFFSET` rows
+    """
+
+    OPT_POS_OFFSET = 2
+
+    def __init__(
+        self,
+        vocab_dim: int,
+        hidden_dim: int = 768,
+        padding_idx: int = 0,
+        max_seq_len: int = 2048,
+    ):
+        """Set up the word and positional embedding tables
+
+        :param vocab_dim: The size of the vocabulary
+        :param hidden_dim: The size of each embedding vector
+        :param padding_idx: The padding index of the word embeddings
+        :param max_seq_len: The maximum number of positions that can be embedded
+        """
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_dim, hidden_dim, padding_idx)
+        self.position_embeddings = nn.Embedding(
+            max_seq_len + OPTLearnedPositionalEmbedding.OPT_POS_OFFSET, hidden_dim
+        )
+
+    def forward(
+        self, x: torch.Tensor, token_type: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        """Takes a tensor of shape `[B, T]` and an optional `token_type` of same shape
+
+        :param x: A tensor of word one-hots, shape `[B, T]`
+        :param token_type: Ignored for OPT!
+        :return: The sum of the positional and word embeddings
+        """
+        embed = self.word_embeddings(x)
+
+        # Create the position ids directly on the input's device
+        position = self.position_embeddings(
+            torch.arange(x.shape[-1], dtype=x.dtype, device=x.device)
+            + OPTLearnedPositionalEmbedding.OPT_POS_OFFSET
+        ).unsqueeze(0)
+
+        return embed + position
+
+    @property
+    def weight(self):
+        """Access word_embeddings weights
+
+        :return: The word_embeddings weights
+        """
+        return self.word_embeddings.weight
+
+
+class OPTTransformerLM(PreLayerNormTransformerEncoder):
+    """OPT LM predicts tokens from left-to-right, with pre-layer-norm encoders"""
+
+    def __init__(
+        self,
+        vocab_size: int,
+        padding_idx: int = 0,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        layer_norm_eps: float = 1e-12,
+        activation: nn.Module = nn.ReLU(),
+        feed_forward_size: Optional[int] = None,
+        max_seq_len: int = 2048,
+        **kwargs,
+    ):
+        """Set up a left-to-right LM whose output projection is tied to the word embeddings
+
+        :param vocab_size: The size of the input vocabulary
+        :param padding_idx: The padding index, defaults to 0
+        :param hidden_size: The number of hidden units
+        :param num_heads: The number of heads for multi-headed attn
+        :param num_layers: The number of transformer layers in the architecture
+        :param dropout: The value to apply for dropout
+        :param layer_norm_eps: The noising term for layer norm
+        :param activation: The activation function to use throughout
+        :param feed_forward_size: An optional value to set for the FFN MLP output size
+        :param max_seq_len: The maximum seq len, which is also the size of the causal mask
+        :param kwargs: Accepted but not used
+        """
+        super().__init__(
+            OPTLearnedPositionalEmbedding,
+            vocab_size,
+            padding_idx,
+            hidden_size,
+            num_heads,
+            num_layers,
+            dropout,
+            layer_norm_eps,
+            activation,
+            feed_forward_size,
+            max_seq_len,
+        )
+        self.activation = activation
+
+        # Lower-triangular mask of shape `[1, 1, max_seq_len, max_seq_len]`
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    (
+                        max_seq_len,
+                        max_seq_len,
+                    ),
+                    dtype=torch.uint8,
+                )
+            )
+            .unsqueeze(0)
+            .unsqueeze(0),
+        )
+
+        self.output_layer = WeightTiedVocabProjection(self.embeddings.word_embeddings)
+        self.apply(self.init_layer_weights)
+
+    def create_loss(self):
+        """Create the training loss
+
+        :return: A cross-entropy loss that ignores targets equal to 0
+        """
+        return nn.CrossEntropyLoss(ignore_index=0)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        token_type: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Apply the encoder from the parent, followed by penultimate and output projection
+
+        :param x: A one-hot (long) tensor of shape `[B, T]`
+        :param mask: An optional mask to take in for attention, combined with the causal mask
+        :param token_type: Not used by this model
+        :return: The output of the vocab projection
+        """
+        input_mask = self.causal_mask[:, :, : x.shape[1], : x.shape[1]]
+        if mask is not None:
+            input_mask = mask & input_mask.to(dtype=torch.bool)
+
+        y = super().forward(x, input_mask)
+        y = self.output_layer(y)
+        return y
+
+
+class OPTTransformerPooledEncoder(PreLayerNormTransformerEncoder):
+    """Use our Transformer encoder with a pooling head.
+
+    We will use this model for classification
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        padding_idx: int = 0,
+        hidden_size: int = 768,
+        num_heads: int = 12,
+        num_layers: int = 12,
+        dropout: float = 0.1,
+        layer_norm_eps: float = 1e-12,
+        activation: nn.Module = nn.ReLU(),
+        feed_forward_size: Optional[int] = None,
+        output: Optional[nn.Module] = None,
+        max_seq_len: int = 2048,
+        pool_id: Optional[int] = None,
+        **kwargs,
+    ):
+        """Set up initialization for a (pre-layer-norm) Transformer with pooling output.
+
+        :param vocab_size: The size of the input vocabulary
+        :param padding_idx: The padding index, defaults to 0
+        :param hidden_size: The number of hidden units
+        :param num_heads: The number of heads for multi-headed attn.  Should divide evenly into hidden_size
+        :param num_layers: The number of transformer layers (MHA+FFN) in the architecture
+        :param dropout: The value to apply for dropout
+        :param layer_norm_eps: The noising term for layer norm
+        :param activation: The activation function to use throughout
+        :param feed_forward_size: An optional value to set for the FFN MLP output size, defaults to 4*hidden_size
+        :param output: An optional projection layer to apply at the end
+        :param max_seq_len: The maximum seq len, defaults to 2048
+        :param pool_id: An optional integer value to use for the pooling token.  If not set, we use mean pooling
+        """
+        super().__init__(
+            OPTLearnedPositionalEmbedding,
+            vocab_size,
+            padding_idx,
+            hidden_size,
+            num_heads,
+            num_layers,
+            dropout,
+            layer_norm_eps,
+            activation,
+            feed_forward_size,
+            max_seq_len,
+        )
+
+        self.pooling = self.mean_pool if pool_id is None else self.pool_by_id
+        self.pool_id = pool_id
+
+        self.output = output if output else nn.Identity()
+        self.register_buffer(
+            "causal_mask",
+            torch.tril(
+                torch.ones(
+                    (
+                        max_seq_len,
+                        max_seq_len,
+                    ),
+                    dtype=torch.uint8,
+                )
+            )
+            .unsqueeze(0)
+            .unsqueeze(0),
+        )
+
+        self.apply(self.init_layer_weights)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        token_type: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Embed and encode the input with a causal mask, then pool and apply the output layer
+
+        :param x: A one-hot (long) tensor of shape `[B, T]`
+        :param mask: An optional mask to take in for attention, combined with the causal mask
+        :param token_type: Passed on to the embeddings
+        :return: The pooled output, after the output layer
+        """
+        input_mask = self.causal_mask[:, :, : x.shape[1], : x.shape[1]]
+        if mask is not None:
+            input_mask = mask & input_mask.to(dtype=torch.bool)
+
+        y = self.embeddings(x, token_type)
+        for t in self.encoder:
+            y = t(y, input_mask)
+
+        y = self.pooling(x, y)
+        return self.output(y)
+
+    def pool_by_id(self, inputs, embeddings):
+        """Select the embeddings at the positions where the input equals `pool_id`
+
+        :param inputs: The input ids
+        :param embeddings: The encoder output
+        :return: The selected embeddings
+        """
+        return embeddings[inputs == self.pool_id]
+
+    def mean_pool(self, inputs, embeddings):
+        """Average the embeddings over dim 1, leaving out the padded positions
+
+        :param inputs: The input ids, where `padding_idx` marks padding
+        :param embeddings: The encoder output
+        :return: The mean over the non-padded positions
+        """
+        mask = inputs != self.padding_idx
+        seq_lengths = mask.sum(1).float()
+        embeddings = embeddings.masked_fill(~mask.unsqueeze(-1), 0.0)
+        return embeddings.sum(1) / seq_lengths.unsqueeze(-1)
+
+
+class OPTCreator:
+    """Create OPT models from HF checkpoints"""
+
+    @classmethod
+    def convert_state_dict(cls, tlm, bert_state_dict):
+        """Rename the fields of an HF OPT state dict to our names and load them into `tlm`
+
+        Each checkpoint field name is rewritten with the string replacements below.  A field is only
+        loaded if its rewritten name is a field of `tlm`, everything else is reported back
+
+        :param tlm: The model to load the weights into
+        :param bert_state_dict: The HF checkpoint state dict
+        :return: The fields of `tlm` that were not set, and the checkpoint fields that were not used
+        """
+        tlm_field_names = set(tlm.state_dict())
+        hf_field_names = bert_state_dict.keys()
+
+        # Example of what has been reported back for a checkpoint (this may be stale):
+        # Unset params: {'layer_norm.weight', 'embeddings.word_embeddings.weight', 'layer_norm.bias'}
+        # Unused checkpoint fields: {'model.decoder.embed_tokens.weight', 'model.decoder.final_layer_norm.bias', 'model.decoder.final_layer_norm.weight'}
+
+        unused_checkpoint_fields = set(hf_field_names)
+        remap = {}
+        for field_name in hf_field_names:
+            new_field_name = field_name.replace(
+                "model.decoder.embed_tokens", "embeddings.word_embeddings"
+            )
+            new_field_name = new_field_name.replace(
+                "model.decoder.embed_positions", "embeddings.position_embeddings"
+            )
+            # The top-level final layer norm has to be renamed before the generic
+            # `model.decoder.` prefix is, the per-layer ones are renamed at the end
+            new_field_name = new_field_name.replace(
+                "model.decoder.final_layer_norm", "layer_norm"
+            )
+            new_field_name = new_field_name.replace("model.decoder.", "encoder.")
+
+            new_field_name = new_field_name.replace("self_attn", "self_attention")
+            new_field_name = new_field_name.replace("k_proj", "key")
+            new_field_name = new_field_name.replace("q_proj", "query")
+            new_field_name = new_field_name.replace("v_proj", "value")
+            new_field_name = new_field_name.replace("out_proj", "output")
+            new_field_name = new_field_name.replace(".layers", "")
+            new_field_name = new_field_name.replace(
+                "attention.output.dense", "self_attention.output"
+            )
+            new_field_name = new_field_name.replace("fc1", "ffn.0")
+            new_field_name = new_field_name.replace("fc2", "ffn.2")
+            new_field_name = new_field_name.replace(
+                "final_layer_norm", "output_layer_norm"
+            )
+            if new_field_name in tlm_field_names:
+                tlm_field_names.remove(new_field_name)
+                unused_checkpoint_fields.remove(field_name)
+                remap[new_field_name] = bert_state_dict[field_name]
+
+        tlm.load_state_dict(remap, strict=False)
+        return tlm_field_names, unused_checkpoint_fields
+
+    @classmethod
+    def get_vocab_and_hidden_dims(cls, hf_dict: dict) -> tuple:
+        """Find the word embedding weights in a state dict and return their shape
+
+        :param hf_dict: A state dict, with either HF OPT field names or our own
+        :return: The shape of the word embeddings, `(vocab_size, hidden_size)`
+        """
+        try:
+            embeddings_weight = hf_dict[
+                [k for k in hf_dict if "decoder.embed_tokens.weight" in k][0]
+            ]
+        except IndexError:
+            # No HF field name matched, so look for our own field name instead
+            embeddings_weight = hf_dict[
+                [k for k in hf_dict if "embeddings.word_embeddings.weight" in k][0]
+            ]
+        return embeddings_weight.shape
+
+    @classmethod
+    def _load_checkpoint(cls, checkpoint_file_or_dir: str, map_location=None):
+        """Load a state dict from a file, or from `pytorch_model.bin` inside a directory
+
+        :param checkpoint_file_or_dir: A checkpoint file, or a directory with a `pytorch_model.bin`
+        :param map_location: Passed through to `torch.load`
+        :return: The loaded state dict
+        """
+        if os.path.isdir(checkpoint_file_or_dir):
+            checkpoint = os.path.join(checkpoint_file_or_dir, "pytorch_model.bin")
+        else:
+            checkpoint = checkpoint_file_or_dir
+        # NOTE: torch.load unpickles the file, only use checkpoints from a trusted source
+        return torch.load(checkpoint, map_location=map_location)
+
+    @classmethod
+    def lm_from_pretrained(
+        cls, checkpoint_file_or_dir: str, map_location=None, **kwargs
+    ):
+        """Create an `OPTTransformerLM` and load its weights from a checkpoint
+
+        :param checkpoint_file_or_dir: A checkpoint file, or a directory with a `pytorch_model.bin`
+        :param map_location: Passed through to `torch.load`
+        :param kwargs: Passed to `OPTTransformerLM`, `hidden_size` defaults to the checkpoint's
+        :return: The LM, with the checkpoint weights loaded
+        """
+        hf_dict = cls._load_checkpoint(checkpoint_file_or_dir, map_location)
+        vocab_size, hidden_size = cls.get_vocab_and_hidden_dims(hf_dict)
+        # Size the model from the checkpoint unless the caller passed a `hidden_size`
+        kwargs.setdefault("hidden_size", hidden_size)
+        tlm = OPTTransformerLM(vocab_size, **kwargs)
+        missing, unused = cls.convert_state_dict(tlm, hf_dict)
+        logger.info("Unset params: %s", missing)
+        logger.info("Unused checkpoint fields: %s", unused)
+        return tlm
+
+    @classmethod
+    def pooled_enc_from_pretrained(
+        cls, checkpoint_file_or_dir: str, map_location=None, pool_id=None, **kwargs
+    ):
+        """Create an `OPTTransformerPooledEncoder` and load its weights from a checkpoint
+
+        :param checkpoint_file_or_dir: A checkpoint file, or a directory with a `pytorch_model.bin`
+        :param map_location: Passed through to `torch.load`
+        :param pool_id: The id of the pooling token, if not set the encoder uses mean pooling
+        :param kwargs: Passed to `OPTTransformerPooledEncoder`, `hidden_size` defaults to the checkpoint's
+        :return: The pooled encoder, with the checkpoint weights loaded
+        """
+        hf_dict = cls._load_checkpoint(checkpoint_file_or_dir, map_location)
+        vocab_size, hidden_size = cls.get_vocab_and_hidden_dims(hf_dict)
+        # Size the model from the checkpoint unless the caller passed a `hidden_size`
+        kwargs.setdefault("hidden_size", hidden_size)
+        enc = OPTTransformerPooledEncoder(vocab_size, pool_id=pool_id, **kwargs)
+        missing, unused = cls.convert_state_dict(enc, hf_dict)
+        logger.info("Unset params: %s", missing)
+        logger.info("Unused checkpoint fields: %s", unused)
+        return enc