Commit 409b174

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 21, 2024
1 parent cf28bc3 commit 409b174
Showing 1 changed file with 16 additions and 15 deletions.
examples/llm/g_retriever.py: 16 additions & 15 deletions
@@ -8,7 +8,6 @@
 
 import pandas as pd
 import torch
-from torch import Tensor
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm
 
@@ -18,10 +17,10 @@
 from torch_geometric.nn.models import GAT, GRetriever
 from torch_geometric.nn.nlp import LLM
 
+
 # Define a function to compute evaluation metrics
 def compute_metrics(eval_output):
-    """
-    Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
+    """Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
 
     Args:
         eval_output (list): List of dictionaries containing prediction and label information.
@@ -74,10 +73,10 @@ def compute_metrics(eval_output):
     print(f'Recall: {recall:.4f}')
     print(f'F1: {f1:.4f}')
 
+
 # Define a function to save model parameters to a file
 def save_params_dict(model, save_path):
-    """
-    Save model parameters to a file.
+    """Save model parameters to a file.
 
     Args:
         model (torch.nn.Module): Model to save.
@@ -93,10 +92,10 @@ def save_params_dict(model, save_path):
             del state_dict[k]  # Delete parameters that do not require gradient
     torch.save(state_dict, save_path)
 
+
 # Define a function to load model parameters from a file
 def load_params_dict(model, save_path):
-    """
-    Load model parameters from a file.
+    """Load model parameters from a file.
 
     Args:
         model (torch.nn.Module): Model to load.
@@ -106,10 +105,10 @@ def load_params_dict(model, save_path):
     model.load_state_dict(state_dict)
     return model
 
+
 # Define a function to compute the loss for a given model and batch
 def get_loss(model, batch, model_save_name):
-    """
-    Compute the loss for a given model and batch.
+    """Compute the loss for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to compute loss for.
@@ -125,10 +124,10 @@ def get_loss(model, batch, model_save_name):
     return model(batch.question, batch.x, batch.edge_index, batch.batch,
                  batch.label, batch.edge_attr, batch.desc)
 
+
 # Define a function to perform inference for a given model and batch
 def inference_step(model, batch, model_save_name):
-    """
-    Perform inference for a given model and batch.
+    """Perform inference for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to perform inference with.
@@ -144,6 +143,7 @@ def inference_step(model, batch, model_save_name):
     return model.inference(batch.question, batch.x, batch.edge_index,
                            batch.batch, batch.edge_attr, batch.desc)
 
+
 # Define the training loop
 def train(
     num_epochs,
@@ -155,8 +155,7 @@ def train(
     checkpointing=False,
     tiny_llama=False,
 ):
-    """
-    Train the model for a specified number of epochs.
+    """Train the model for a specified number of epochs.
 
     Args:
         num_epochs (int): Number of epochs to train for.
@@ -168,6 +167,7 @@ def train(
         checkpointing (bool): Whether to save model checkpoints. Default: False.
         tiny_llama (bool): Whether to use the tiny LLaMA model. Default: False.
     """
+
     # Define the learning rate schedule
     def adjust_learning_rate(param_group, LR, epoch):
         min_lr = 5e-6
@@ -276,7 +276,7 @@ def adjust_learning_rate(param_group, LR, epoch):
             # Update learning rate schedule
             if (step + 1) % grad_steps == 0:
                 lr = adjust_learning_rate(optimizer.param_groups[0], lr,
-                                      step / len(train_loader) + epoch)
+                                          step / len(train_loader) + epoch)
 
             # Update optimizer
             optimizer.step()
@@ -347,6 +347,7 @@ def adjust_learning_rate(param_group, LR, epoch):
     save_params_dict(model, f'{model_save_name}.pt')
     torch.save(eval_output, f'{model_save_name}_eval_outs.pt')
 
+
 if __name__ == '__main__':
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -376,4 +377,4 @@ def adjust_learning_rate(param_group, LR, epoch):
)

# Print total time
print(f"Total Time: {time.time() - start_time:2f}s")
print(f"Total Time: {time.time() - start_time:2f}s")
