Skip to content

Commit

Permalink
fix edge_softmax implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
akensert committed Sep 19, 2024
1 parent 6cede57 commit 8927de2
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions molexpress/ops/gnn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,12 @@ def aggregate(
return node_state_updated

def edge_softmax(score, edge_dst):
numerator = keras.ops.exp(score - keras.ops.max(score, axis=0, keepdims=True))
num_segments = keras.ops.max(edge_dst) + 1
num_segments = keras.ops.maximum(keras.ops.max(edge_dst) + 1, 1)
score_max = keras.ops.segment_max(score, edge_dst, num_segments, sorted=False)
score_max = gather(score_max, edge_dst)
numerator = keras.ops.exp(score - score_max)
denominator = keras.ops.segment_sum(numerator, edge_dst, num_segments, sorted=False)
expected_rank = len(keras.ops.shape(denominator))
current_rank = len(keras.ops.shape(edge_dst))
for _ in range(expected_rank - current_rank):
edge_dst = keras.ops.expand_dims(edge_dst, axis=-1)
denominator = keras.ops.take_along_axis(denominator, edge_dst, axis=0)
denominator = gather(denominator, edge_dst)
return numerator / denominator

def gather(
Expand Down

0 comments on commit 8927de2

Please sign in to comment.