Skip to content

Commit

Permalink
Fix merge error when loading state dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Gerum committed Jan 26, 2024
1 parent bdb905d commit d12925f
Show file tree
Hide file tree
Showing 5 changed files with 2,197 additions and 2,329 deletions.
2 changes: 1 addition & 1 deletion external/hannah-tvm
Submodule hannah-tvm updated 1 files
+1 −1 pyproject.toml
4 changes: 2 additions & 2 deletions hannah/models/factory/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ class NetworkConfig:

_target_: str = "hannah.models.factory.create_cnn"
name: str = MISSING
norm: Optional[NormConfig] = BNConfig()
act: Optional[ActConfig] = ActConfig()
norm: Optional[NormConfig] = field(default_factory=BNConfig)
act: Optional[ActConfig] = field(default_factory=ActConfig)
qconfig: Optional[Any] = None
conv: List[MajorBlockConfig] = field(default_factory=list)
linear: List[LinearConfig] = field(default_factory=list)
Expand Down
19 changes: 19 additions & 0 deletions hannah/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,25 @@ def train(
_convert_="partial",
)

if config.get("input_file", None):
msglogger.info("Loading initial weights from model %s", config.input_file)
lit_module.setup("train")
lit_module.load_from_state_dict(config.input_file, strict=False)

if config["auto_lr"]:
# run lr finder (counts as one epoch)
lr_finder = lit_trainer.lr_find(lit_module)

# inspect results
fig = lr_finder.plot()
fig.savefig("./learning_rate.png")

# recreate module with updated config
suggested_lr = lr_finder.suggestion()
config["lr"] = suggested_lr

lit_trainer.tune(lit_module)

logging.info("Starting training")
# PL TRAIN
ckpt_path = None
Expand Down
Loading

0 comments on commit d12925f

Please sign in to comment.