Using AI tools to improve commenting of base G-retriever example
riship committed Dec 21, 2024
2 parents 0b8d565 + 6429b92 commit 05b8e03
Showing 1 changed file with 24 additions and 29 deletions.
53 changes: 24 additions & 29 deletions examples/llm/g_retriever.py
@@ -23,15 +23,14 @@
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

from torch_geometric import seed_everything
from torch_geometric.datasets import WebQSPDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GAT, GRetriever
from torch_geometric.nn.nlp import LLM


def compute_metrics(eval_output):
"""
Compute evaluation metrics (hit, precision, recall, F1) from prediction output.
"""Compute evaluation metrics (hit, precision, recall, F1) from prediction output.
Parameters:
eval_output (list): List of dictionaries containing prediction output.
@@ -43,7 +42,8 @@ def compute_metrics(eval_output):
df = pd.concat([pd.DataFrame(d) for d in eval_output])

# Initialize lists to store metrics
all_hit = [] # List of boolean values indicating whether prediction matches label
all_hit = [
] # List of boolean values indicating whether prediction matches label
all_precision = [] # List of precision values
all_recall = [] # List of recall values
all_f1 = [] # List of F1 values
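
The per-example loop that fills these lists is collapsed in this diff. A rough sketch of what such a computation can look like (the '|'-separated answer format and the exact parsing are assumptions for illustration, not taken from this commit):

    # Illustrative sketch only -- the real loop is collapsed in this diff, and
    # the '|'-separated entity format is an assumption about the answer strings.
    for pred, label in zip(df.pred.tolist(), df.label.tolist()):
        pred_entities = {e.strip() for e in str(pred).split('|') if e.strip()}
        label_entities = {e.strip() for e in str(label).split('|') if e.strip()}
        matches = pred_entities & label_entities

        all_hit.append(len(matches) > 0)  # at least one answer entity recovered
        precision = len(matches) / max(len(pred_entities), 1)
        recall = len(matches) / max(len(label_entities), 1)
        f1 = (2 * precision * recall / (precision + recall)
              if (precision + recall) > 0 else 0.0)
        all_precision.append(precision)
        all_recall.append(recall)
        all_f1.append(f1)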
@@ -96,45 +96,43 @@ def compute_metrics(eval_output):


def save_params_dict(model, save_path):
"""
Saves a model's parameters to a file while excluding non-trainable weights.
"""Saves a model's parameters to a file while excluding non-trainable weights.
Args:
model (torch.nn.Module): The model to save parameters from.
save_path (str): The path to save the parameters to.
"""
# Get the model's state dictionary, which contains all its parameters
state_dict = model.state_dict()

# Create a dictionary mapping parameter names to their requires_grad status
param_grad_dict = {
k: v.requires_grad
for (k, v) in model.named_parameters()
}

# Remove non-trainable parameters from the state dictionary
for k in list(state_dict.keys()):
if k in param_grad_dict.keys() and not param_grad_dict[k]:
del state_dict[k] # Delete parameters that do not require gradient

# Save the filtered state dictionary to the specified path
torch.save(state_dict, save_path)


def load_params_dict(model, save_path):
# Load the saved model parameters from the specified file path
state_dict = torch.load(save_path)

# Update the model's parameters with the loaded state dictionary
model.load_state_dict(state_dict)

# Return the model with updated parameters
return model
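
Together, save_params_dict and load_params_dict give a small checkpoint round trip for the trainable weights. A hedged usage sketch (the file name is illustrative):

    # Illustrative usage -- the path is arbitrary.
    ckpt_path = 'g_retriever_best.pt'

    # After training: persist only parameters with requires_grad=True (e.g. the
    # GNN and projection layers), skipping frozen LLM weights to keep the file small.
    save_params_dict(model, ckpt_path)

    # Later: rebuild the model exactly as it was constructed for training, then
    # restore the trained parameters on top of it.
    model = load_params_dict(model, ckpt_path)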


def get_loss(model, batch, model_save_name: str) -> Tensor:
"""
Compute the loss for a given model and batch of data.
"""Compute the loss for a given model and batch of data.
Args:
model: The model to compute the loss for.
@@ -160,9 +158,9 @@ def get_loss(model, batch, model_save_name: str) -> Tensor:
)
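
The body of get_loss is collapsed here. Like inference_step below, it presumably dispatches on model_save_name; a sketch under the assumption that the forward signatures mirror the inference calls shown later in this file (not something this diff confirms):

    # Illustrative sketch only -- the committed body is collapsed in this diff.
    if model_save_name == 'llm':
        # Pure LLM baseline: question, target label, and textual graph description.
        return model(batch.question, batch.label, batch.desc)
    else:  # GNN+LLM (G-Retriever)
        # Additionally pass the retrieved subgraph so the GNN can encode it.
        return model(batch.question, batch.x, batch.edge_index, batch.batch,
                     batch.label, batch.edge_attr, batch.desc)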



def inference_step(model, batch, model_save_name):
"""
Performs inference on a given batch of data using the provided model.
"""Performs inference on a given batch of data using the provided model.
Args:
model (nn.Module): The model to use for inference.
@@ -176,7 +174,7 @@ def inference_step(model, batch, model_save_name):
if model_save_name == 'llm':
# Perform inference on the question and textual graph description
return model.inference(batch.question, batch.desc)
else: # (GNN+LLM)
else: # (GNN+LLM)
return model.inference(
batch.question,
batch.x, # node features
@@ -188,17 +186,16 @@


def train(
num_epochs, # Total number of training epochs
hidden_channels, # Number of hidden channels in GNN
num_gnn_layers, # Number of GNN layers
batch_size, # Training batch size
eval_batch_size, # Evaluation batch size
lr, # Initial learning rate
checkpointing=False, # Whether to checkpoint model
tiny_llama=False, # Whether to use tiny LLaMA model
num_epochs, # Total number of training epochs
hidden_channels, # Number of hidden channels in GNN
num_gnn_layers, # Number of GNN layers
batch_size, # Training batch size
eval_batch_size, # Evaluation batch size
lr, # Initial learning rate
checkpointing=False, # Whether to checkpoint model
tiny_llama=False, # Whether to use tiny LLaMA model
):
"""
Train a GNN+LLM model on WebQSP dataset.
"""Train a GNN+LLM model on WebQSP dataset.
Args:
num_epochs (int): Total number of training epochs.
@@ -213,10 +210,8 @@ def train(
Returns:
None
"""

def adjust_learning_rate(param_group, LR, epoch):
"""
Decay learning rate with half-cycle cosine after warmup.
"""Decay learning rate with half-cycle cosine after warmup.
Args:
param_group (dict): Parameter group.
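
The adjust_learning_rate body is collapsed at the end of this diff. A minimal sketch of the warmup plus half-cycle cosine schedule its docstring describes, assuming a warmup epoch count and a small floor learning rate (both assumptions, not values from this commit):

    import math

    def cosine_with_warmup(param_group, LR, epoch, warmup_epochs=1,
                           num_epochs=2, min_lr=5e-6):
        # Hypothetical helper mirroring adjust_learning_rate; the extra keyword
        # arguments are assumptions, not part of the committed signature.
        if epoch < warmup_epochs:
            # Linear warmup from 0 up to the base learning rate.
            lr = LR * epoch / warmup_epochs
        else:
            # Half-cycle cosine decay from LR down to min_lr.
            progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
            lr = min_lr + (LR - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
        param_group['lr'] = lr
        return lr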
