Introduce optuna.artifacts to the PyTorch checkpoint example #280

Conversation
@nabenabe0928 Could you review this PR?

This pull request has not seen any recent activity.
Thank you for your PR! Could you check my comments?
pytorch/pytorch_checkpoint.py
Outdated
checkpoint = torch.load(checkpoint_path)
if trial_number is not None:
    study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
    artifact_id = study.trials[trial_number].user_attrs["artifact_id"]
If the process is terminated before the first checkpoint, the artifact will not be saved, so check if it exists.
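A minimal sketch of such a guard, assuming Optuna >= 3.4 (where `optuna.artifacts.download_artifact` is available) and the `artifact_store`, `trial_number`, and `checkpoint_path` names from the surrounding example; the use of `dict.get()` is an illustrative choice, not necessarily this PR's exact fix:

```python
import optuna
import torch

study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
artifact_id = study.trials[trial_number].user_attrs.get("artifact_id")
if artifact_id is None:
    # The trial was terminated before its first checkpoint was uploaded,
    # so there is nothing to resume from; train from scratch.
    checkpoint = None
else:
    optuna.artifacts.download_artifact(
        artifact_store=artifact_store,
        artifact_id=artifact_id,
        file_path=checkpoint_path,
    )
    checkpoint = torch.load(checkpoint_path)
```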
pytorch/pytorch_checkpoint.py
Outdated
@@ -159,9 +158,15 @@ def objective(trial):
            "optimizer_state_dict": optimizer.state_dict(),
            "accuracy": accuracy,
        },
-       tmp_checkpoint_path,
+       "./tmp_model.pt",
Could you change the checkpoint path for each trial? If we run this script with multiple processes, the saved models can be corrupted by other processes.
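A sketch of a per-trial path, following the `f"./tmp_model_{trial.number}.pt"` pattern that this PR later adopts; the `model`, `optimizer`, and `accuracy` names are assumed from the example's `objective`:

```python
# Make the checkpoint path unique per trial so that concurrent worker
# processes never write to the same file.
checkpoint_path = f"./tmp_model_{trial.number}.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "accuracy": accuracy,
    },
    checkpoint_path,
)
```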
This pull request has not seen any recent activity.
Co-authored-by: Naoto Mizuno <[email protected]>
`timeout 5 python pytorch_checkpoint.py`

The fix could be like this.
Thank you for your review! I have fixed it according to your suggestion.
Thank you for your update. It's almost LGTM. Could you check my comment?
    file_path=f"./tmp_model_{trial.number}.pt",
    artifact_id=artifact_id,
)
checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
Could you remove the temporary file here?
os.remove(f"./tmp_model_{trial.number}.pt")
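Put together, the download, load, and cleanup sequence could look like this sketch (assuming the `artifact_store` and `artifact_id` from the diff above):

```python
import os

tmp_path = f"./tmp_model_{trial.number}.pt"
optuna.artifacts.download_artifact(
    artifact_store=artifact_store,
    artifact_id=artifact_id,
    file_path=tmp_path,
)
checkpoint = torch.load(tmp_path)
# Delete the local copy once the checkpoint is loaded into memory.
os.remove(tmp_path)
```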
Thank you for your comment. I have fixed this.
LGTM!
Motivation

Currently, the PyTorch checkpoint example uses the local file system to save and manage checkpoints, and does not yet reflect the recent optuna.artifacts functionality.

Description of the changes

Update the PyTorch checkpoint example to save and load checkpoints via optuna.artifacts.
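For reference, the core `optuna.artifacts` flow this PR introduces could be sketched as follows. This is a minimal sketch rather than the example's exact code: the dummy checkpoint, file names, and single-trial study are illustrative, and it assumes Optuna >= 3.4 (where `download_artifact` is available) with a `FileSystemArtifactStore` backend.

```python
import os

import optuna
import torch
from optuna.artifacts import FileSystemArtifactStore, download_artifact, upload_artifact

# Store artifacts on the local file system (any artifact store backend works).
base_path = "./artifacts"
os.makedirs(base_path, exist_ok=True)
artifact_store = FileSystemArtifactStore(base_path=base_path)


def objective(trial):
    # Stand-in for a real training loop that produces a checkpoint.
    checkpoint = {"model_state_dict": {"weight": torch.zeros(1)}, "accuracy": 0.0}
    checkpoint_path = f"./tmp_model_{trial.number}.pt"  # per-trial path (multi-process safe)
    torch.save(checkpoint, checkpoint_path)

    # Upload the checkpoint and record its id so a restarted process can find it.
    artifact_id = upload_artifact(trial, checkpoint_path, artifact_store)
    trial.set_user_attr("artifact_id", artifact_id)
    os.remove(checkpoint_path)  # remove the temporary file once it is uploaded
    return checkpoint["accuracy"]


study = optuna.create_study(
    study_name="pytorch_checkpoint", storage="sqlite:///example.db", load_if_exists=True
)
study.optimize(objective, n_trials=1)

# On resume, guard against trials that died before their first checkpoint.
artifact_id = study.trials[0].user_attrs.get("artifact_id")
if artifact_id is not None:
    download_artifact(
        artifact_store=artifact_store, artifact_id=artifact_id, file_path="./resumed.pt"
    )
    checkpoint = torch.load("./resumed.pt")
```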