diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index e0a57bf3edde..02faca26b577 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -185,14 +185,14 @@ def inference_step(model, batch, model_save_name): def train( - num_epochs, # Total number of training epochs - hidden_channels, # Number of hidden channels in GNN - num_gnn_layers, # Number of GNN layers - batch_size, # Training batch size - eval_batch_size, # Evaluation batch size - lr, # Initial learning rate - checkpointing=False, # Whether to checkpoint model - tiny_llama=False, # Whether to use tiny LLaMA model + num_epochs, # Total number of training epochs + hidden_channels, # Number of hidden channels in GNN + num_gnn_layers, # Number of GNN layers + batch_size, # Training batch size + eval_batch_size, # Evaluation batch size + lr, # Initial learning rate + checkpointing=False, # Whether to checkpoint model + tiny_llama=False, # Whether to use tiny LLaMA model ): """Train a GNN+LLM model on WebQSP dataset.