diff --git a/examples/llm/g_retriever.py b/examples/llm/g_retriever.py index 18fc4bc7eeb3..1b615a165eab 100644 --- a/examples/llm/g_retriever.py +++ b/examples/llm/g_retriever.py @@ -149,8 +149,15 @@ def get_loss(model, batch, model_save_name: str) -> Tensor: # For LLM models, compute the loss using the question, label, and desc inputs return model(batch.question, batch.label, batch.desc) else: # (GNN+LLM) - return model(batch.question, batch.x, batch.edge_index, batch.batch, - batch.label, batch.edge_attr, batch.desc) + return model( + batch.question, + batch.x, # node features + batch.edge_index, # edge indices + batch.batch, # batch indices + batch.edge_attr, # edge attributes + batch.label, # answers (labels) + batch.desc # description + ) def inference_step(model, batch, model_save_name):