Update JAX trainer predict() to reduce memory consumption
fchollet committed Jan 24, 2024
1 parent f97e3c7 commit d9c69a0
Showing 1 changed file with 12 additions and 0 deletions.
keras/backend/jax/trainer.py (12 additions, 0 deletions)
@@ -632,6 +632,9 @@ def append_to_outputs(batch_outputs, outputs):
         non_trainable_variables = [
             v.value for v in self.non_trainable_variables
         ]
+        self._purge_model_variables(
+            optimizer_variables=False, metric_variables=False
+        )
         state = (trainable_variables, non_trainable_variables)
         outputs = None
         for step, x in epoch_iterator.enumerate_epoch():
@@ -641,7 +644,16 @@ def append_to_outputs(batch_outputs, outputs):
             callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
+
+        self._jax_state = {
+            # I wouldn't recommend modifying non-trainable model state
+            # during predict(), but it's allowed.
+            "trainable_variables": state[0],
+            "non_trainable_variables": state[1],
+        }
+        self.jax_state_sync()
         callbacks.on_predict_end()
+        self._jax_state = None
         return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
 
     def train_on_batch(
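Background on why this reduces memory: in the JAX backend, predict() runs a stateless, jitted function over an explicit (trainable_variables, non_trainable_variables) state tuple, so once the raw v.value arrays have been extracted, leaving the Keras variables populated would keep a second copy of every weight buffer alive on device for the whole loop. Purging the variables up front and syncing state back afterwards avoids that duplication. Below is a minimal sketch of the pattern outside of Keras; PurgingModel, purge, restore, and run_inference are hypothetical names invented for illustration, and only the jax and numpy calls are real APIs.

import jax
import jax.numpy as jnp
import numpy as np


class PurgingModel:
    """Toy stand-in for a model that owns on-device weight buffers."""

    def __init__(self, dim):
        self.weights = [jnp.ones((dim, dim))]  # on-device buffers

    def purge(self):
        # Drop the model's own references so the values extracted in
        # run_inference() become the only live copy on device. This is
        # analogous to what _purge_model_variables() does for variables.
        self.weights = None

    def restore(self, values):
        # Write the raw arrays back into the model, playing the role of
        # jax_state_sync() after the predict loop.
        self.weights = list(values)


@jax.jit
def predict_step(weights, x):
    # Stateless step: the weights travel as explicit arguments, matching
    # the (state, x) signature of the trainer's predict_function.
    return x @ weights[0]


def run_inference(model, batches):
    state = [w for w in model.weights]  # extract raw values, like v.value
    model.purge()  # avoid holding a duplicate copy of every weight
    outputs = []
    for x in batches:
        outputs.append(np.asarray(predict_step(state, x)))
    model.restore(state)  # sync the (possibly updated) state back
    return np.concatenate(outputs)


model = PurgingModel(dim=4)
print(run_inference(model, [jnp.ones((2, 4)), jnp.ones((2, 4))]).shape)  # (4, 4)

The key point is that state is the only reference to the weights while the loop runs; restoring it afterwards, and then clearing the temporary _jax_state reference as the commit does, means peak memory holds one copy of the weights instead of two.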
