From 8927de272f54e749efdc31379d0986bf2ade714e Mon Sep 17 00:00:00 2001 From: Alexander Kensert Date: Thu, 19 Sep 2024 17:22:03 +0200 Subject: [PATCH] fix edge_softmax implementation --- molexpress/ops/gnn_ops.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/molexpress/ops/gnn_ops.py b/molexpress/ops/gnn_ops.py index 13bebdc..8e009dd 100644 --- a/molexpress/ops/gnn_ops.py +++ b/molexpress/ops/gnn_ops.py @@ -87,14 +87,12 @@ def aggregate( return node_state_updated def edge_softmax(score, edge_dst): - numerator = keras.ops.exp(score - keras.ops.max(score, axis=0, keepdims=True)) - num_segments = keras.ops.max(edge_dst) + 1 + num_segments = keras.ops.maximum(keras.ops.max(edge_dst) + 1, 1) + score_max = keras.ops.segment_max(score, edge_dst, num_segments, sorted=False) + score_max = gather(score_max, edge_dst) + numerator = keras.ops.exp(score - score_max) denominator = keras.ops.segment_sum(numerator, edge_dst, num_segments, sorted=False) - expected_rank = len(keras.ops.shape(denominator)) - current_rank = len(keras.ops.shape(edge_dst)) - for _ in range(expected_rank - current_rank): - edge_dst = keras.ops.expand_dims(edge_dst, axis=-1) - denominator = keras.ops.take_along_axis(denominator, edge_dst, axis=0) + denominator = gather(denominator, edge_dst) return numerator / denominator def gather(