From 8156646cdef0c753162fc6b68a11d57cce642f95 Mon Sep 17 00:00:00 2001 From: spktrm Date: Thu, 14 Dec 2023 07:30:15 +1000 Subject: [PATCH] fix mean logit calculation in neurd loss for rnad --- open_spiel/python/algorithms/rnad/rnad.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/open_spiel/python/algorithms/rnad/rnad.py b/open_spiel/python/algorithms/rnad/rnad.py index 1cadfaffa8..5d3d51b210 100644 --- a/open_spiel/python/algorithms/rnad/rnad.py +++ b/open_spiel/python/algorithms/rnad/rnad.py @@ -562,6 +562,9 @@ def get_loss_nerd(logit_list: Sequence[chex.Array], """Define the nerd loss.""" assert isinstance(importance_sampling_correction, list) loss_pi_list = [] + + num_valid_actions = jnp.sum(legal_actions, axis=-1, keepdims=True) + for k, (logit_pi, pi, q_vr, is_c) in enumerate( zip(logit_list, policy_list, q_vr_list, importance_sampling_correction)): assert logit_pi.shape[0] == q_vr.shape[0] @@ -570,9 +573,12 @@ def get_loss_nerd(logit_list: Sequence[chex.Array], adv_pi = is_c * adv_pi # importance sampling correction adv_pi = jnp.clip(adv_pi, a_min=-clip, a_max=clip) adv_pi = lax.stop_gradient(adv_pi) - - logits = logit_pi - jnp.mean( - logit_pi * legal_actions, axis=-1, keepdims=True) + + valid_logit_sum = jnp.sum(logit_pi * legal_actions, axis=-1, keepdims=True) + mean_logit = valid_logit_sum / num_valid_actions + + # Subtract only the mean of the valid logits + logits = logit_pi - mean_logit threshold_center = jnp.zeros_like(logits)