Skip to content

Commit

Permalink
Adds a "debug" mode to Baselines, in order to support decoding state.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700497596
  • Loading branch information
Petar Veličković authored and copybara-github committed Nov 27, 2024
1 parent a891bc2 commit 06cc87b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
23 changes: 19 additions & 4 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def __init__(
hint_repred_mode: str = 'soft',
name: str = 'base_model',
nb_msg_passing_steps: int = 1,
debug: bool = False,
):
"""Constructor for BaselineModel.
Expand Down Expand Up @@ -199,6 +200,7 @@ def __init__(
- 'hard_on_eval', which is soft for training and hard for evaluation.
name: Model name.
nb_msg_passing_steps: Number of message passing steps per hint.
debug: If True, the model runs in debug mode, outputting all hidden states.
Raises:
ValueError: if `encode_hints=True` and `decode_hints=False`.
Expand All @@ -223,6 +225,7 @@ def __init__(
self.opt = optax.adam(learning_rate)

self.nb_msg_passing_steps = nb_msg_passing_steps
self.debug = debug

self.nb_dims = []
if isinstance(dummy_trajectory, _Feedback):
Expand Down Expand Up @@ -253,7 +256,8 @@ def _use_net(*args, **kwargs):
processor_factory, use_lstm, encoder_init,
dropout_prob, hint_teacher_forcing,
hint_repred_mode,
self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
self.nb_dims, self.nb_msg_passing_steps,
self.debug)(*args, **kwargs)

self.net_fn = hk.transform(_use_net)
pmap_args = dict(axis_name='batch', devices=jax.local_devices())
Expand Down Expand Up @@ -324,18 +328,25 @@ def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index):
def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features,
algorithm_index: int, return_hints: bool,
return_all_outputs: bool):
outs, hint_preds = self.net_fn.apply(
net_outputs = self.net_fn.apply(
params, rng_key, [features],
repred=True, algorithm_index=algorithm_index,
return_hints=return_hints,
return_all_outputs=return_all_outputs)
if self.debug:
outs, hint_preds, hidden_states = net_outputs
else:
outs, hint_preds = net_outputs
outs = decoders.postprocess(self._spec[algorithm_index],
outs,
sinkhorn_temperature=0.1,
sinkhorn_steps=50,
hard=True,
)
return outs, hint_preds
if self.debug:
return outs, hint_preds, hidden_states
else:
return outs, hint_preds

def compute_grad(
self,
Expand Down Expand Up @@ -394,12 +405,16 @@ def predict(self, rng_key: hk.PRNGSequence, features: _Features,

def _loss(self, params, rng_key, feedback, algorithm_index):
"""Calculates model loss f(feedback; params)."""
output_preds, hint_preds = self.net_fn.apply(
outputs = self.net_fn.apply(
params, rng_key, [feedback.features],
repred=False,
algorithm_index=algorithm_index,
return_hints=True,
return_all_outputs=False)
if self.debug:
output_preds, hint_preds, _ = outputs
else:
output_preds, hint_preds = outputs

nb_nodes = _nb_nodes(feedback, is_chunked=False)
lengths = feedback.features.lengths
Expand Down
8 changes: 7 additions & 1 deletion clrs/_src/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
hint_repred_mode='soft',
nb_dims=None,
nb_msg_passing_steps=1,
debug=False,
name: str = 'net',
):
"""Constructs a `Net`."""
Expand All @@ -102,6 +103,7 @@ def __init__(
self.use_lstm = use_lstm
self.encoder_init = encoder_init
self.nb_msg_passing_steps = nb_msg_passing_steps
self.debug = debug

def _msg_passing_step(self,
mp_state: _MessagePassingScanState,
Expand Down Expand Up @@ -186,7 +188,7 @@ def _msg_passing_step(self,
accum_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
hint_preds=hint_preds if return_hints else None,
output_preds=output_preds if return_all_outputs else None,
hiddens=None, lstm_state=None)
hiddens=hiddens if self.debug else None, lstm_state=None)

# Complying to jax.scan, the first returned value is the state we carry over
# the second value is the output that will be stacked over steps.
Expand Down Expand Up @@ -318,6 +320,10 @@ def invert(d):
output_preds = output_mp_state.output_preds
hint_preds = invert(accum_mp_state.hint_preds)

if self.debug:
hiddens = jnp.stack([v for v in accum_mp_state.hiddens])
return output_preds, hint_preds, hiddens

return output_preds, hint_preds

def _construct_encoders_decoders(self):
Expand Down

0 comments on commit 06cc87b

Please sign in to comment.