From 409b1747e380f1665c48c196622ba82106533c1a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 21 Dec 2024 00:50:13 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/llm/g_retriever.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py
index f3ec1e40c4f8..04d9adfb1c78 100644
--- a/examples/llm/g_retriever.py
+++ b/examples/llm/g_retriever.py
@@ -8,7 +8,6 @@
 
 import pandas as pd
 import torch
-from torch import Tensor
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm
 
@@ -18,10 +17,10 @@
 from torch_geometric.nn.models import GAT, GRetriever
 from torch_geometric.nn.nlp import LLM
 
+
 # Define a function to compute evaluation metrics
 def compute_metrics(eval_output):
-    """
-    Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
+    """Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
 
     Args:
         eval_output (list): List of dictionaries containing prediction and label information.
@@ -74,10 +73,10 @@ def compute_metrics(eval_output):
     print(f'Recall: {recall:.4f}')
     print(f'F1: {f1:.4f}')
 
+
 # Define a function to save model parameters to a file
 def save_params_dict(model, save_path):
-    """
-    Save model parameters to a file.
+    """Save model parameters to a file.
 
     Args:
         model (torch.nn.Module): Model to save.
@@ -93,10 +92,10 @@ def save_params_dict(model, save_path):
             del state_dict[k]  # Delete parameters that do not require gradient
     torch.save(state_dict, save_path)
 
+
 # Define a function to load model parameters from a file
 def load_params_dict(model, save_path):
-    """
-    Load model parameters from a file.
+    """Load model parameters from a file.
 
     Args:
         model (torch.nn.Module): Model to load.
@@ -106,10 +105,10 @@ def load_params_dict(model, save_path):
     model.load_state_dict(state_dict)
     return model
 
+
 # Define a function to compute the loss for a given model and batch
 def get_loss(model, batch, model_save_name):
-    """
-    Compute the loss for a given model and batch.
+    """Compute the loss for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to compute loss for.
@@ -125,10 +124,10 @@ def get_loss(model, batch, model_save_name):
         return model(batch.question, batch.x, batch.edge_index, batch.batch,
                      batch.label, batch.edge_attr, batch.desc)
 
+
 # Define a function to perform inference for a given model and batch
 def inference_step(model, batch, model_save_name):
-    """
-    Perform inference for a given model and batch.
+    """Perform inference for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to perform inference with.
@@ -144,6 +143,7 @@ def inference_step(model, batch, model_save_name):
         return model.inference(batch.question, batch.x, batch.edge_index,
                                batch.batch, batch.edge_attr, batch.desc)
 
+
 # Define the training loop
 def train(
     num_epochs,
@@ -155,8 +155,7 @@ def train(
     checkpointing=False,
     tiny_llama=False,
 ):
-    """
-    Train the model for a specified number of epochs.
+    """Train the model for a specified number of epochs.
 
     Args:
         num_epochs (int): Number of epochs to train for.
@@ -168,6 +167,7 @@ def train(
         checkpointing (bool): Whether to save model checkpoints. Default: False.
         tiny_llama (bool): Whether to use the tiny LLaMA model. Default: False.
     """
+
    # Define the learning rate schedule
     def adjust_learning_rate(param_group, LR, epoch):
         min_lr = 5e-6
@@ -276,7 +276,7 @@ def adjust_learning_rate(param_group, LR, epoch):
             # Update learning rate schedule
             if (step + 1) % grad_steps == 0:
                 lr = adjust_learning_rate(optimizer.param_groups[0], lr,
-                                           step / len(train_loader) + epoch)
+                                          step / len(train_loader) + epoch)
 
             # Update optimizer
             optimizer.step()
@@ -347,6 +347,7 @@ def adjust_learning_rate(param_group, LR, epoch):
     save_params_dict(model, f'{model_save_name}.pt')
     torch.save(eval_output, f'{model_save_name}_eval_outs.pt')
 
+
 if __name__ == '__main__':
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
@@ -376,4 +377,4 @@ def adjust_learning_rate(param_group, LR, epoch):
     )
 
     # Print total time
-    print(f"Total Time: {time.time() - start_time:2f}s")
\ No newline at end of file
+    print(f"Total Time: {time.time() - start_time:2f}s")