diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index e481a92..fe8647e 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -155,6 +155,7 @@ def __init__( hint_repred_mode: str = 'soft', name: str = 'base_model', nb_msg_passing_steps: int = 1, + debug: bool = False, ): """Constructor for BaselineModel. @@ -199,6 +200,7 @@ def __init__( - 'hard_on_eval', which is soft for training and hard for evaluation. name: Model name. nb_msg_passing_steps: Number of message passing steps per hint. + debug: If True, the model run in debug mode, outputting all hidden state. Raises: ValueError: if `encode_hints=True` and `decode_hints=False`. @@ -223,6 +225,7 @@ def __init__( self.opt = optax.adam(learning_rate) self.nb_msg_passing_steps = nb_msg_passing_steps + self.debug = debug self.nb_dims = [] if isinstance(dummy_trajectory, _Feedback): @@ -253,7 +256,8 @@ def _use_net(*args, **kwargs): processor_factory, use_lstm, encoder_init, dropout_prob, hint_teacher_forcing, hint_repred_mode, - self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs) + self.nb_dims, self.nb_msg_passing_steps, + self.debug)(*args, **kwargs) self.net_fn = hk.transform(_use_net) pmap_args = dict(axis_name='batch', devices=jax.local_devices()) @@ -324,18 +328,25 @@ def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index): def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features, algorithm_index: int, return_hints: bool, return_all_outputs: bool): - outs, hint_preds = self.net_fn.apply( + net_outputs = self.net_fn.apply( params, rng_key, [features], repred=True, algorithm_index=algorithm_index, return_hints=return_hints, return_all_outputs=return_all_outputs) + if self.debug: + outs, hint_preds, hidden_states = net_outputs + else: + outs, hint_preds = net_outputs outs = decoders.postprocess(self._spec[algorithm_index], outs, sinkhorn_temperature=0.1, sinkhorn_steps=50, hard=True, ) - return outs, hint_preds + if self.debug: + return outs, hint_preds, hidden_states + else: + return outs, hint_preds def compute_grad( self, @@ -394,12 +405,16 @@ def predict(self, rng_key: hk.PRNGSequence, features: _Features, def _loss(self, params, rng_key, feedback, algorithm_index): """Calculates model loss f(feedback; params).""" - output_preds, hint_preds = self.net_fn.apply( + outputs = self.net_fn.apply( params, rng_key, [feedback.features], repred=False, algorithm_index=algorithm_index, return_hints=True, return_all_outputs=False) + if self.debug: + output_preds, hint_preds, _ = outputs + else: + output_preds, hint_preds = outputs nb_nodes = _nb_nodes(feedback, is_chunked=False) lengths = feedback.features.lengths diff --git a/clrs/_src/nets.py b/clrs/_src/nets.py index 56f329b..bfcd6b9 100644 --- a/clrs/_src/nets.py +++ b/clrs/_src/nets.py @@ -85,6 +85,7 @@ def __init__( hint_repred_mode='soft', nb_dims=None, nb_msg_passing_steps=1, + debug=False, name: str = 'net', ): """Constructs a `Net`.""" @@ -102,6 +103,7 @@ def __init__( self.use_lstm = use_lstm self.encoder_init = encoder_init self.nb_msg_passing_steps = nb_msg_passing_steps + self.debug = debug def _msg_passing_step(self, mp_state: _MessagePassingScanState, @@ -186,7 +188,7 @@ def _msg_passing_step(self, accum_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars hint_preds=hint_preds if return_hints else None, output_preds=output_preds if return_all_outputs else None, - hiddens=None, lstm_state=None) + hiddens=hiddens if self.debug else None, lstm_state=None) # Complying to jax.scan, the first returned value is the state we carry over # the second value is the output that will be stacked over steps. @@ -318,6 +320,10 @@ def invert(d): output_preds = output_mp_state.output_preds hint_preds = invert(accum_mp_state.hint_preds) + if self.debug: + hiddens = jnp.stack([v for v in accum_mp_state.hiddens]) + return output_preds, hint_preds, hiddens + return output_preds, hint_preds def _construct_encoders_decoders(self):