From 7950bc18dc3124d3c81b7697731e85c2ff9f0970 Mon Sep 17 00:00:00 2001
From: Eren Yenigül
Date: Mon, 7 Aug 2023 13:39:48 +0200
Subject: [PATCH] fix: arbitrary output size

The number of output classes was hardcoded to 2 in both the final
linear layer and the GlobalAttention gate. Read it from
config["num_cls"] instead, and default it to 2 in the trainer's
default config so existing behavior is preserved.
---
 compy/models/graphs/pytorch_geom_model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/compy/models/graphs/pytorch_geom_model.py b/compy/models/graphs/pytorch_geom_model.py
index be95a4b..b04dead 100644
--- a/compy/models/graphs/pytorch_geom_model.py
+++ b/compy/models/graphs/pytorch_geom_model.py
@@ -18,11 +18,11 @@ def __init__(self, config):
         annotation_size = config["hidden_size_orig"]
         hidden_size = config["gnn_h_size"]
         n_steps = config["num_timesteps"]
-        num_cls = 2
+        num_cls = config["num_cls"]
 
         self.reduce = nn.Linear(annotation_size, hidden_size)
         self.conv = GatedGraphConv(hidden_size, n_steps)
-        self.agg = GlobalAttention(nn.Linear(hidden_size, 1), nn.Linear(hidden_size, 2))
+        self.agg = GlobalAttention(nn.Linear(hidden_size, 1), nn.Linear(hidden_size, num_cls))
        self.lin = nn.Linear(hidden_size, num_cls)
 
     def forward(
@@ -51,6 +51,7 @@ def __init__(self, config=None, num_types=None):
                 "learning_rate": 0.001,
                 "batch_size": 64,
                 "num_epochs": 1000,
+                "num_cls": 2
             }
 
         super().__init__(config)
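
Usage note (not part of the patch): a minimal sketch of constructing the
patched model with an arbitrary number of output classes. It assumes the
module class whose __init__ is shown above is importable as Net from
compy/models/graphs/pytorch_geom_model.py; the concrete config values
here are illustrative assumptions, not values from the patch.

    # Hypothetical usage sketch; class name Net and config values are assumptions.
    from compy.models.graphs.pytorch_geom_model import Net

    config = {
        "hidden_size_orig": 92,  # assumed node annotation size
        "gnn_h_size": 32,        # assumed GGNN hidden size
        "num_timesteps": 4,      # assumed number of propagation steps
        "num_cls": 5,            # now configurable; was hardcoded to 2
    }
    model = Net(config)          # final layer becomes nn.Linear(32, 5)

After this change, "num_cls" is a required key: config["num_cls"] raises
KeyError if it is missing. The trainer's default config supplies 2, but
hand-built configs must include the key explicitly.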