Skip to content

Commit

Permalink
Allow parsing trainer options with additional gradient cuts (#23)
Browse files Browse the repository at this point in the history
* Add options to cut bptt and diverted chain

* Add test for parsing the new options
  • Loading branch information
Ceyron authored Oct 30, 2024
1 parent 9ead97a commit e6b3027
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
17 changes: 15 additions & 2 deletions apebench/_base_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,26 +349,39 @@ def get_trainer(self, *, train_config: str):
)
elif train_args[0].lower() == "sup":
num_rollout_steps = int(train_args[1])
if len(train_args) > 2:
cut_bptt = train_args[2].lower() == "true"
else:
cut_bptt = False
trainer = tx.trainer.SupervisedTrainer(
train_trjs,
optimizer=optimizer,
num_training_steps=self.num_training_steps,
batch_size=self.batch_size,
num_rollout_steps=num_rollout_steps,
cut_bptt=False,
cut_bptt=cut_bptt,
time_level_weights=None,
callback_fn=callback_fn,
)
elif train_args[0].lower() == "div":
num_rollout_steps = int(train_args[1])
if len(train_args) > 2:
cut_bptt = train_args[2].lower() == "true"
else:
cut_bptt = False
if len(train_args) > 3:
cut_div_chain = train_args[3].lower() == "true"
else:
cut_div_chain = False
trainer = tx.trainer.DivertedChainBranchOneTrainer(
train_trjs,
ref_stepper=ref_stepper,
optimizer=optimizer,
num_training_steps=self.num_training_steps,
batch_size=self.batch_size,
num_rollout_steps=num_rollout_steps,
cut_bptt=False,
cut_bptt=cut_bptt,
cut_div_chain=cut_div_chain,
time_level_weights=None,
callback_fn=callback_fn,
)
Expand Down
65 changes: 65 additions & 0 deletions tests/test_trainer_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import apebench


def test_trainer_parsing():
    """Check that ``get_trainer`` parses ``train_config`` strings correctly.

    Covers the ``"one"``, ``"sup"``, and ``"div"`` configurations, including
    the optional ``cut_bptt`` and ``cut_div_chain`` flags appended after
    semicolons (e.g. ``"div;3;True;True"``).
    """
    scenario = apebench.scenarios.difficulty.Advection()

    # The "one" configuration always trains on a single rollout step.
    one_trainer = scenario.get_trainer(train_config="one")
    assert one_trainer.loss_configuration.num_rollout_steps == 1
    del one_trainer

    # Supervised configs: (config string, expected cut_bptt).
    # The optional third token toggles backpropagation-through-time cutting.
    sup_cases = [
        ("sup;3", False),
        ("sup;3;True", True),
        ("sup;3;False", False),
    ]
    for config, expected_cut_bptt in sup_cases:
        sup_trainer = scenario.get_trainer(train_config=config)
        assert sup_trainer.loss_configuration.num_rollout_steps == 3
        assert sup_trainer.loss_configuration.cut_bptt is expected_cut_bptt
        del sup_trainer

    # Diverted-chain configs:
    # (config string, expected cut_bptt, expected cut_div_chain).
    # The optional fourth token additionally toggles cutting the diverted chain.
    div_cases = [
        ("div;3", False, False),
        ("div;3;True", True, False),
        ("div;3;False", False, False),
        ("div;3;True;True", True, True),
        ("div;3;False;True", False, True),
    ]
    for config, expected_cut_bptt, expected_cut_div_chain in div_cases:
        div_trainer = scenario.get_trainer(train_config=config)
        assert div_trainer.loss_configuration.num_rollout_steps == 3
        assert div_trainer.loss_configuration.cut_bptt is expected_cut_bptt
        assert (
            div_trainer.loss_configuration.cut_div_chain
            is expected_cut_div_chain
        )
        del div_trainer

0 comments on commit e6b3027

Please sign in to comment.