diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index bdcea8831e0a..18fc4bc7eeb3 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -10,7 +10,6 @@ Example repo for integration with Neo4j Graph DB: https://github.com/neo4j-product-examples/neo4j-gnn-llm-example """ -# Import necessary libraries and modules import argparse import gc import math @@ -20,6 +19,7 @@ import pandas as pd import torch +from torch import Tensor from torch.nn.utils import clip_grad_norm_ from tqdm import tqdm @@ -29,160 +29,196 @@ from torch_geometric.nn.models import GAT, GRetriever from torch_geometric.nn.nlp import LLM - -# Define a function to compute evaluation metrics def compute_metrics(eval_output): - """Compute evaluation metrics (Hit, Precision, Recall, F1) from the output of the inference step. + """ + Compute evaluation metrics (hit, precision, recall, F1) from prediction output. - Args: - eval_output (list): List of dictionaries containing prediction and label information. + Parameters: + eval_output (list): List of dictionaries containing prediction output. + + Returns: + None (prints metrics to console) """ + # Concatenate prediction output into a single DataFrame df = pd.concat([pd.DataFrame(d) for d in eval_output]) - all_hit = [] - all_precision = [] - all_recall = [] - all_f1 = [] - # Loop through each prediction and label pair + # Initialize lists to store metrics + all_hit = [] # List of boolean values indicating whether prediction matches label + all_precision = [] # List of precision values + all_recall = [] # List of recall values + all_f1 = [] # List of F1 values + + # Iterate over prediction-label pairs for pred, label in zip(df.pred.tolist(), df.label.tolist()): try: - # Extract the first prediction (split by '[/s]') and label + # Preprocess prediction string pred = pred.split('[/s]')[0].strip().split('|') - label = label.split('|') - # Compute hit, precision, recall, and F1 score - hit = len(re.findall(pred[0], label)) > 0 + # Check if prediction matches label + hit = re.findall(pred[0], label) + all_hit.append(len(hit) > 0) + + # Compute precision, recall, and F1 + label = label.split('|') matches = set(pred).intersection(set(label)) precision = len(matches) / len(set(pred)) recall = len(matches) / len(set(label)) + + # Handle division by zero if recall + precision == 0: f1 = 0 else: f1 = 2 * precision * recall / (precision + recall) - # Append computed scores to lists - all_hit.append(hit) + # Store metrics all_precision.append(precision) all_recall.append(recall) all_f1.append(f1) except Exception as e: - # Handle exceptions (e.g., when label or prediction is empty) + # Handle exceptions by printing error message and skipping iteration print(f'Label: {label}') print(f'Pred: {pred}') print(f'Exception: {e}') print('------------------') - # Compute average scores + # Compute average metrics hit = sum(all_hit) / len(all_hit) precision = sum(all_precision) / len(all_precision) recall = sum(all_recall) / len(all_recall) f1 = sum(all_f1) / len(all_f1) - # Print average scores + # Print metrics to console print(f'Hit: {hit:.4f}') print(f'Precision: {precision:.4f}') print(f'Recall: {recall:.4f}') print(f'F1: {f1:.4f}') -# Define a function to save model parameters to a file def save_params_dict(model, save_path): - """Save model parameters to a file. + """ + Saves a model's parameters to a file while excluding non-trainable weights. Args: - model (torch.nn.Module): Model to save. - save_path (str): Path to save the model parameters. + model (torch.nn.Module): The model to save parameters from. + save_path (str): The path to save the parameters to. """ + # Get the model's state dictionary, which contains all its parameters state_dict = model.state_dict() + + # Create a dictionary mapping parameter names to their requires_grad status param_grad_dict = { k: v.requires_grad for (k, v) in model.named_parameters() } + + # Remove non-trainable parameters from the state dictionary for k in list(state_dict.keys()): if k in param_grad_dict.keys() and not param_grad_dict[k]: del state_dict[k] # Delete parameters that do not require gradient + + # Save the filtered state dictionary to the specified path torch.save(state_dict, save_path) -# Define a function to load model parameters from a file def load_params_dict(model, save_path): - """Load model parameters from a file. - - Args: - model (torch.nn.Module): Model to load. - save_path (str): Path to load the model parameters from. - """ + # Load the saved model parameters from the specified file path state_dict = torch.load(save_path) + + # Update the model's parameters with the loaded state dictionary model.load_state_dict(state_dict) + + # Return the model with updated parameters return model -# Define a function to compute the loss for a given model and batch -def get_loss(model, batch, model_save_name): - """Compute the loss for a given model and batch. +def get_loss(model, batch, model_save_name: str) -> Tensor: + """ + Compute the loss for a given model and batch of data. Args: - model (torch.nn.Module): Model to compute loss for. - batch (torch_geometric.data.Data): Batch to compute loss for. - model_save_name (str): Name of the model (used to determine loss computation). + model: The model to compute the loss for. + batch: The batch of data to compute the loss for. + model_save_name: The name of the model being used (e.g. 'llm'). Returns: - loss (torch.Tensor): Computed loss. + Tensor: The computed loss. """ + # Check the type of model being used to determine the input arguments if model_save_name == 'llm': + # For LLM models, compute the loss using the question, label, and desc inputs return model(batch.question, batch.label, batch.desc) - else: + else: # (GNN+LLM) return model(batch.question, batch.x, batch.edge_index, batch.batch, batch.label, batch.edge_attr, batch.desc) -# Define a function to perform inference for a given model and batch def inference_step(model, batch, model_save_name): - """Perform inference for a given model and batch. + """ + Performs inference on a given batch of data using the provided model. Args: - model (torch.nn.Module): Model to perform inference with. - batch (torch_geometric.data.Data): Batch to perform inference on. - model_save_name (str): Name of the model (used to determine inference computation). + model (nn.Module): The model to use for inference. + batch: The batch of data to process. + model_save_name (str): The name of the model (e.g. 'llm'). Returns: - pred (torch.Tensor): Predicted output. + The output of the inference step. """ + # Check the type of model being used to determine the input arguments if model_save_name == 'llm': + # Perform inference on the question and textual graph description return model.inference(batch.question, batch.desc) - else: - return model.inference(batch.question, batch.x, batch.edge_index, - batch.batch, batch.edge_attr, batch.desc) + else: # (GNN+LLM) + return model.inference( + batch.question, + batch.x, # node features + batch.edge_index, # edge indices + batch.batch, # batch indices + batch.edge_attr, # edge attributes + batch.desc # description + ) -# Define the training loop def train( - num_epochs, - hidden_channels, - num_gnn_layers, - batch_size, - eval_batch_size, - lr, - checkpointing=False, - tiny_llama=False, + num_epochs, # Total number of training epochs + hidden_channels, # Number of hidden channels in GNN + num_gnn_layers, # Number of GNN layers + batch_size, # Training batch size + eval_batch_size, # Evaluation batch size + lr, # Initial learning rate + checkpointing=False, # Whether to checkpoint model + tiny_llama=False, # Whether to use tiny LLaMA model ): - """Train the model for a specified number of epochs. + """ + Train a GNN+LLM model on WebQSP dataset. Args: - num_epochs (int): Number of epochs to train for. - hidden_channels (int): Number of hidden channels for the GNN model. - num_gnn_layers (int): Number of layers for the GNN model. - batch_size (int): Batch size for training. - eval_batch_size (int): Batch size for evaluation. - lr (float): Learning rate for the optimizer. - checkpointing (bool): Whether to save model checkpoints. Default: False. - tiny_llama (bool): Whether to use the tiny LLaMA model. Default: False. + num_epochs (int): Total number of training epochs. + hidden_channels (int): Number of hidden channels in GNN. + num_gnn_layers (int): Number of GNN layers. + batch_size (int): Training batch size. + eval_batch_size (int): Evaluation batch size. + lr (float): Initial learning rate. + checkpointing (bool, optional): Whether to checkpoint model. Defaults to False. + tiny_llama (bool, optional): Whether to use tiny LLaMA model. Defaults to False. + + Returns: + None """ - # Define the learning rate schedule def adjust_learning_rate(param_group, LR, epoch): - # Decay the learning rate with half-cycle cosine after warmup + """ + Decay learning rate with half-cycle cosine after warmup. + + Args: + param_group (dict): Parameter group. + LR (float): Learning rate. + epoch (int): Current epoch. + + Returns: + float: Adjusted learning rate. + """ min_lr = 5e-6 warmup_epochs = 1 if epoch < warmup_epochs: @@ -194,7 +230,7 @@ def adjust_learning_rate(param_group, LR, epoch): param_group['lr'] = lr return lr - # Initialize training start time + # Start training time start_time = time.time() # Load dataset and create data loaders @@ -204,10 +240,6 @@ def adjust_learning_rate(param_group, LR, epoch): val_dataset = WebQSPDataset(path, split='val') test_dataset = WebQSPDataset(path, split='test') - # Set random seed for reproducibility - seed_everything(42) - - # Create data loaders train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, pin_memory=True, shuffle=True) val_loader = DataLoader(val_dataset, batch_size=eval_batch_size, @@ -219,7 +251,7 @@ def adjust_learning_rate(param_group, LR, epoch): gc.collect() torch.cuda.empty_cache() - # Initialize GNN model + # Create GNN model gnn = GAT( in_channels=1024, hidden_channels=hidden_channels, @@ -228,7 +260,7 @@ def adjust_learning_rate(param_group, LR, epoch): heads=4, ) - # Initialize LLaMA model + # Create LLaMA model if tiny_llama: llm = LLM( model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1', @@ -239,12 +271,12 @@ def adjust_learning_rate(param_group, LR, epoch): llm = LLM(model_name='meta-llama/Llama-2-7b-chat-hf', num_params=7) model = GRetriever(llm=llm, gnn=gnn) - # Set model name for saving and loading + # Set model save name model_save_name = 'gnn_llm' if num_gnn_layers != 0 else 'llm' if model_save_name == 'llm': model = llm - # Initialize optimizer and learning rate schedule + # Create optimizer params = [p for _, p in model.named_parameters() if p.requires_grad] optimizer = torch.optim.AdamW([ { @@ -253,57 +285,41 @@ def adjust_learning_rate(param_group, LR, epoch): 'weight_decay': 0.05 }, ], betas=(0.9, 0.95)) - grad_steps = 2 - # Initialize best epoch and validation loss + # Initialize best epoch and best validation loss best_epoch = 0 best_val_loss = float('inf') - # Train for specified number of epochs + # Train model for epoch in range(num_epochs): - # Set model to training mode model.train() - - # Initialize epoch loss epoch_loss = 0 - - # Print preparation time if epoch == 0: print(f"Total Preparation Time: {time.time() - start_time:2f}s") start_time = time.time() print("Training beginning...") - - # Create progress bar for training epoch_str = f'Epoch: {epoch + 1}|{num_epochs}' loader = tqdm(train_loader, desc=epoch_str) - - # Train on each batch for step, batch in enumerate(loader): optimizer.zero_grad() loss = get_loss(model, batch, model_save_name) loss.backward() - # Clip gradients to prevent exploding gradients clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1) - # Update learning rate schedule - if (step + 1) % grad_steps == 0: - lr = adjust_learning_rate(optimizer.param_groups[0], lr, - step / len(train_loader) + epoch) + if (step + 1) % 2 == 0: + adjust_learning_rate(optimizer.param_groups[0], lr, + step / len(train_loader) + epoch) - # Update optimizer optimizer.step() epoch_loss = epoch_loss + float(loss) - # Update learning rate schedule - if (step + 1) % grad_steps == 0: + if (step + 1) % 2 == 0: lr = optimizer.param_groups[0]['lr'] - - # Print training loss train_loss = epoch_loss / len(train_loader) print(epoch_str + f', Train Loss: {train_loss:4f}') - # Evaluate on validation set + # Evaluate model val_loss = 0 eval_output = [] model.eval() @@ -313,8 +329,6 @@ def adjust_learning_rate(param_group, LR, epoch): val_loss += loss.item() val_loss = val_loss / len(val_loader) print(epoch_str + f", Val Loss: {val_loss:4f}") - - # Save model checkpoint if validation loss improves if checkpointing and val_loss < best_val_loss: print("Checkpointing best model...") best_val_loss = val_loss @@ -325,7 +339,7 @@ def adjust_learning_rate(param_group, LR, epoch): torch.cuda.empty_cache() torch.cuda.reset_max_memory_allocated() - # Load best model checkpoint if checkpointing is enabled + # Load best checkpoint if necessary if checkpointing and best_epoch != num_epochs - 1: print("Loading best checkpoint...") model = load_params_dict( @@ -333,7 +347,7 @@ def adjust_learning_rate(param_group, LR, epoch): f'{model_save_name}_best_val_loss_ckpt.pt', ) - # Evaluate on test set + # Evaluate model on test set model.eval() eval_output = [] print("Final evaluation...") @@ -350,19 +364,18 @@ def adjust_learning_rate(param_group, LR, epoch): eval_output.append(eval_data) progress_bar_test.update(1) - # Compute evaluation metrics + # Compute metrics compute_metrics(eval_output) - # Print total training time + # Print final training time print(f"Total Training Time: {time.time() - start_time:2f}s") - # Save final model and evaluation output + # Save model and evaluation output save_params_dict(model, f'{model_save_name}.pt') torch.save(eval_output, f'{model_save_name}_eval_outs.pt') if __name__ == '__main__': - # Parse command-line arguments parser = argparse.ArgumentParser() parser.add_argument('--gnn_hidden_channels', type=int, default=1024) parser.add_argument('--num_gnn_layers', type=int, default=4) @@ -374,10 +387,7 @@ def adjust_learning_rate(param_group, LR, epoch): parser.add_argument('--tiny_llama', action='store_true') args = parser.parse_args() - # Print total time start_time = time.time() - - # Train the model train( args.epochs, args.gnn_hidden_channels, @@ -388,6 +398,4 @@ def adjust_learning_rate(param_group, LR, epoch): checkpointing=args.checkpointing, tiny_llama=args.tiny_llama, ) - - # Print total time print(f"Total Time: {time.time() - start_time:2f}s")