Commit

Merge branch 'AI-cleanup' of https://github.com/pyg-team/pytorch_geometric into AI-cleanup
riship committed Dec 21, 2024 · 2 parents ef01be3 + 409b174 · commit 9c3e365
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions examples/llm/g_retriever.py
@@ -17,7 +17,6 @@

 import pandas as pd
 import torch
-from torch import Tensor
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm

@@ -27,10 +26,10 @@
 from torch_geometric.nn.models import GAT, GRetriever
 from torch_geometric.nn.nlp import LLM
 
 
+# Define a function to compute evaluation metrics
 def compute_metrics(eval_output):
-    """
-    Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
+    """Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step.
 
     Args:
         eval_output (list): List of dictionaries containing prediction and label information.
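The function body is collapsed in this view. As a hedged sketch only: assuming each entry of eval_output holds parallel lists under 'pred' and 'label' keys, with multiple answers joined by '|' (the real example's keys and matching rules may differ), the metrics could be computed as:

    import pandas as pd

    def compute_metrics_sketch(eval_output):
        # Flatten the per-batch dictionaries into one DataFrame.
        df = pd.concat([pd.DataFrame(d) for d in eval_output])
        hits, precisions, recalls, f1s = [], [], [], []
        for pred, label in zip(df['pred'], df['label']):
            pred_set = {p.strip() for p in pred.split('|')}
            label_set = {x.strip() for x in label.split('|')}
            matches = pred_set & label_set
            hits.append(len(matches) > 0)
            precision = len(matches) / len(pred_set)
            recall = len(matches) / len(label_set)
            precisions.append(precision)
            recalls.append(recall)
            f1s.append(2 * precision * recall / (precision + recall)
                       if precision + recall > 0 else 0.0)
        print(f'Hit: {sum(hits) / len(hits):.4f}')
        print(f'Precision: {sum(precisions) / len(precisions):.4f}')
        print(f'Recall: {sum(recalls) / len(recalls):.4f}')
        print(f'F1: {sum(f1s) / len(f1s):.4f}')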
@@ -83,10 +82,10 @@ def compute_metrics(eval_output):
     print(f'Recall: {recall:.4f}')
     print(f'F1: {f1:.4f}')
 
 
+# Define a function to save model parameters to a file
 def save_params_dict(model, save_path):
-    """
-    Save model parameters to a file.
+    """Save model parameters to a file.
 
     Args:
         model (torch.nn.Module): Model to save.
@@ -102,10 +101,10 @@ def save_params_dict(model, save_path):
             del state_dict[k]  # Delete parameters that do not require gradient
     torch.save(state_dict, save_path)
 
 
+# Define a function to load model parameters from a file
 def load_params_dict(model, save_path):
-    """
-    Load model parameters from a file.
+    """Load model parameters from a file.
 
     Args:
         model (torch.nn.Module): Model to load.
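Per the comment retained above, save_params_dict drops every parameter that does not require gradients (e.g. frozen LLM weights), so checkpoints stay small. A self-contained toy illustration of that mechanism (the Linear module, file name, and strict=False choice are stand-ins, not part of the example):

    import torch

    toy = torch.nn.Linear(2, 2)
    toy.bias.requires_grad = False  # pretend the bias is frozen

    state_dict = toy.state_dict()
    grad_dict = {k: v.requires_grad for k, v in toy.named_parameters()}
    for k in list(state_dict.keys()):
        if k in grad_dict and not grad_dict[k]:
            del state_dict[k]  # 'bias' is dropped; only 'weight' is saved
    torch.save(state_dict, 'toy.pt')

    # When restoring, strict=False tolerates the missing frozen keys.
    toy.load_state_dict(torch.load('toy.pt'), strict=False)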
@@ -115,10 +114,10 @@ def load_params_dict(model, save_path):
     model.load_state_dict(state_dict)
     return model
 
 
+# Define a function to compute the loss for a given model and batch
 def get_loss(model, batch, model_save_name):
-    """
-    Compute the loss for a given model and batch.
+    """Compute the loss for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to compute loss for.
@@ -134,10 +133,10 @@ def get_loss(model, batch, model_save_name):
     return model(batch.question, batch.x, batch.edge_index, batch.batch,
                  batch.label, batch.edge_attr, batch.desc)
 
 
+# Define a function to perform inference for a given model and batch
 def inference_step(model, batch, model_save_name):
-    """
-    Perform inference for a given model and batch.
+    """Perform inference for a given model and batch.
 
     Args:
         model (torch.nn.Module): Model to perform inference with.
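The two helpers differ only in what they forward: get_loss passes batch.label into model(...) so the forward pass can return a training loss, while inference_step calls model.inference(...) without labels and returns generated predictions. The remaining batch fields (question, x, edge_index, batch, edge_attr, desc) are identical in both calls.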
@@ -153,6 +152,7 @@ def inference_step(model, batch, model_save_name):
     return model.inference(batch.question, batch.x, batch.edge_index,
                            batch.batch, batch.edge_attr, batch.desc)
 
 
+# Define the training loop
 def train(
     num_epochs,
@@ -164,8 +164,7 @@ def train(
     checkpointing=False,
     tiny_llama=False,
 ):
-    """
-    Train the model for a specified number of epochs.
+    """Train the model for a specified number of epochs.
 
     Args:
         num_epochs (int): Number of epochs to train for.
@@ -177,6 +176,7 @@ def train(
         checkpointing (bool): Whether to save model checkpoints. Default: False.
         tiny_llama (bool): Whether to use the tiny LLaMA model. Default: False.
     """
+
     # Define the learning rate schedule
     def adjust_learning_rate(param_group, LR, epoch):
         min_lr = 5e-6
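Only min_lr = 5e-6 of the schedule is visible in this diff. A sketch assuming the common warmup-then-half-cosine schedule used in similar PyG examples (warmup_epochs and num_epochs are assumed parameters here; the real nested function may close over them instead):

    import math

    def adjust_learning_rate(param_group, LR, epoch, warmup_epochs=1,
                             num_epochs=10):
        # Linear warmup, then cosine decay from LR down to min_lr.
        min_lr = 5e-6
        if epoch < warmup_epochs:
            lr = LR * epoch / warmup_epochs
        else:
            progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
            lr = min_lr + (LR - min_lr) * 0.5 * (
                1.0 + math.cos(math.pi * progress))
        param_group['lr'] = lr
        return lr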
@@ -285,7 +285,7 @@ def adjust_learning_rate(param_group, LR, epoch):
             # Update learning rate schedule
             if (step + 1) % grad_steps == 0:
                 lr = adjust_learning_rate(optimizer.param_groups[0], lr,
-                                        step / len(train_loader) + epoch)
+                                          step / len(train_loader) + epoch)
 
             # Update optimizer
             optimizer.step()
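The schedule is stepped once per gradient-accumulation window, and the epoch argument is fractional (step / len(train_loader) + epoch), so the decay moves smoothly within an epoch. A tiny demonstration with illustrative values (grad_steps and the loader length are not the example's defaults):

    grad_steps, steps_per_epoch, epoch = 2, 4, 3
    for step in range(steps_per_epoch):
        if (step + 1) % grad_steps == 0:  # fires once per accumulation window
            frac_epoch = step / steps_per_epoch + epoch
            print(f'lr update at step {step}: epoch={frac_epoch:.2f}')
    # -> lr update at step 1: epoch=3.25
    # -> lr update at step 3: epoch=3.75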
@@ -356,6 +356,7 @@ def adjust_learning_rate(param_group, LR, epoch):
     save_params_dict(model, f'{model_save_name}.pt')
     torch.save(eval_output, f'{model_save_name}_eval_outs.pt')
 
+
 if __name__ == '__main__':
     # Parse command-line arguments
     parser = argparse.ArgumentParser()
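The argument definitions are collapsed. Judging from train()'s signature, the flags plausibly include the epoch count and the two booleans; a hypothetical sketch (flag names and defaults are guesses, not the example's actual interface):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--checkpointing', action='store_true')
    parser.add_argument('--tiny_llama', action='store_true')
    args = parser.parse_args()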
@@ -385,4 +386,4 @@ def adjust_learning_rate(param_group, LR, epoch):
     )
 
     # Print total time
-    print(f"Total Time: {time.time() - start_time:2f}s")
\ No newline at end of file
+    print(f"Total Time: {time.time() - start_time:2f}s")
