
Introduce optuna.artifacts to the PyTorch checkpoint example #280

Merged · 9 commits merged into optuna:main on Oct 25, 2024

Conversation

kAIto47802 (Contributor)

Motivation

Currently, the PyTorch checkpoint example uses the local file system to save and manage checkpoints, and does not yet reflect the recent optuna.artifacts functionality.

Description of the changes

  • Introduced optuna.artifacts (see the sketch below).
  • Removed the use of the local file system.
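
For reference, a minimal self-contained sketch of the artifact-store workflow the example now follows; the ./artifacts directory, the placeholder checkpoint contents, and the trivial objective are illustrative only and are not the example's actual training code:

import os

import optuna
from optuna.artifacts import FileSystemArtifactStore, upload_artifact

# File-system-backed artifact store; the "./artifacts" directory is an arbitrary choice.
os.makedirs("./artifacts", exist_ok=True)
artifact_store = FileSystemArtifactStore(base_path="./artifacts")


def objective(trial):
    # The real example writes a torch checkpoint here; a placeholder file stands in for it.
    checkpoint_path = f"./tmp_model_{trial.number}.pt"
    with open(checkpoint_path, "wb") as f:
        f.write(b"placeholder checkpoint")

    # Upload the checkpoint and record its id on the trial so a retried trial
    # can locate and download it later via optuna.artifacts.download_artifact.
    artifact_id = upload_artifact(
        artifact_store=artifact_store, file_path=checkpoint_path, study_or_trial=trial
    )
    trial.set_user_attr("artifact_id", artifact_id)
    os.remove(checkpoint_path)
    return 0.0


study = optuna.create_study()
study.optimize(objective, n_trials=1)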

@kAIto47802 kAIto47802 changed the title from "Introduce artifact store to the PyTorch checkpoint example" to "Introduce optuna.artifacts to the PyTorch checkpoint example" on Sep 24, 2024
@HideakiImamura (Member)

@nabenabe0928 Could you review this PR?


github-actions bot commented Oct 7, 2024

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale label Oct 7, 2024
@not522 not522 assigned not522 and unassigned nabenabe0928 Oct 9, 2024
@not522 not522 (Member) left a comment

Thank you for your PR! Could you check my comments?

pytorch/pytorch_checkpoint.py (outdated)
checkpoint = torch.load(checkpoint_path)
if trial_number is not None:
study = optuna.load_study(study_name="pytorch_checkpoint", storage="sqlite:///example.db")
artifact_id = study.trials[trial_number].user_attrs["artifact_id"]

If the process is terminated before the first checkpoint, the artifact will not be saved, so check if it exists.
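
A sketch of one way to implement that guard inside objective(trial); the helper name _find_checkpoint_artifact_id is hypothetical, and it assumes the artifact id is stored under the "artifact_id" user attribute as in this example:

from optuna.storages import RetryFailedTrialCallback


def _find_checkpoint_artifact_id(trial):
    # Returns None when there is no retried trial, or when the retried trial was
    # terminated before it uploaded its first checkpoint.
    trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
    if trial_number is None:
        return None
    return trial.study.trials[trial_number].user_attrs.get("artifact_id")

The objective can then skip the download and start from epoch 0 whenever this returns None.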

@@ -159,9 +158,15 @@ def objective(trial):
"optimizer_state_dict": optimizer.state_dict(),
"accuracy": accuracy,
},
tmp_checkpoint_path,
"./tmp_model.pt",

Could you change the path of checkpoint for each trial? If we run this script with multi-process, the saved models can be broken by other processes.
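
A sketch of a per-trial checkpoint path keyed by trial.number, which is what the example ends up using; the save_checkpoint helper name is hypothetical (the example inlines this logic in objective()):

import torch


def save_checkpoint(trial, model, optimizer, epoch, accuracy):
    # Key the temporary checkpoint by trial.number so concurrently running
    # processes never overwrite each other's files.
    tmp_checkpoint_path = f"./tmp_model_{trial.number}.pt"
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "accuracy": accuracy,
        },
        tmp_checkpoint_path,
    )
    return tmp_checkpoint_path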

@github-actions github-actions bot removed the stale label Oct 9, 2024
github-actions bot commented Oct 17, 2024

This pull request has not seen any recent activity.

@github-actions github-actions bot added the stale label Oct 17, 2024
@nabenabe0928 nabenabe0928 removed the stale label Oct 18, 2024
@not522 not522 (Member) commented Oct 22, 2024

retried_trial_number returns the first trial's number in the retry history, so we should check the entire retry history instead. For example, the behavior can be checked by running the following command.

timeout 5 python pytorch_checkpoint.py

The fix could look like this:

$ git diff
diff --git a/pytorch/pytorch_checkpoint.py b/pytorch/pytorch_checkpoint.py
index 35b697d..e7f2b5d 100644
--- a/pytorch/pytorch_checkpoint.py
+++ b/pytorch/pytorch_checkpoint.py
@@ -89,20 +89,25 @@ def objective(trial):
     lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
     optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
 
-    trial_number = RetryFailedTrialCallback.retried_trial_number(trial)
-
-    artifact_id = trial_number and trial.study.trials[trial_number].user_attrs.get("artifact_id")
-    if trial_number is not None and artifact_id is not None:
+    artifact_id = None
+    retry_history = RetryFailedTrialCallback.retry_history(trial)
+    for trial_number in reversed(retry_history):
+        artifact_id = trial.study.trials[trial_number].user_attrs.get("artifact_id")
+        if artifact_id is not None:
+            retry_trial_number = trial_number
+            break
+
+    if artifact_id is not None:
         download_artifact(
             artifact_store=artifact_store,
-            file_path=f"./tmp_model_{trial_number}.pt",
+            file_path=f"./tmp_model_{trial.number}.pt",
             artifact_id=artifact_id,
         )
-        checkpoint = torch.load(f"./tmp_model_{trial_number}.pt")
+        checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")
         epoch = checkpoint["epoch"]
         epoch_begin = epoch + 1
 
-        print(f"Loading a checkpoint from trial {trial_number} in epoch {epoch}.")
+        print(f"Loading a checkpoint from trial {retry_trial_number} in epoch {epoch}.")
 
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
@@ -159,15 +164,15 @@ def objective(trial):
                 "optimizer_state_dict": optimizer.state_dict(),
                 "accuracy": accuracy,
             },
-            f"./tmp_model_{trial_number}.pt",
+            f"./tmp_model_{trial.number}.pt",
         )
         artifact_id = upload_artifact(
             artifact_store=artifact_store,
-            file_path=f"./tmp_model_{trial_number}.pt",
+            file_path=f"./tmp_model_{trial.number}.pt",
             study_or_trial=trial,
         )
         trial.set_user_attr("artifact_id", artifact_id)
-        os.remove(f"./tmp_model_{trial_number}.pt")
+        os.remove(f"./tmp_model_{trial.number}.pt")
 
         # Handle pruning based on the intermediate value.
         if trial.should_prune():
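
For context, the retry history above only exists because the study's storage is configured with a heartbeat and RetryFailedTrialCallback, roughly as in the sketch below; the heartbeat_interval and grace_period values are illustrative, not necessarily those used in the example:

import optuna
from optuna.storages import RDBStorage, RetryFailedTrialCallback

# Trials whose heartbeat stops (e.g. the process is killed by `timeout`) are marked
# as failed and re-enqueued by the callback, which is what produces the retry history.
storage = RDBStorage(
    url="sqlite:///example.db",
    heartbeat_interval=1,
    grace_period=2,
    failed_trial_callback=RetryFailedTrialCallback(),
)
study = optuna.create_study(
    study_name="pytorch_checkpoint", storage=storage, load_if_exists=True
)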

@kAIto47802 kAIto47802 (Contributor, Author)

Thank you for your review! I have fixed it according to your suggestion.

@not522 not522 (Member) left a comment

Thank you for your update. It's almost LGTM. Could you check my comment?

file_path=f"./tmp_model_{trial.number}.pt",
artifact_id=artifact_id,
)
checkpoint = torch.load(f"./tmp_model_{trial.number}.pt")

Could you remove the temporary file here?

os.remove(f"./tmp_model_{trial.number}.pt")

@kAIto47802 kAIto47802 (Contributor, Author) commented Oct 25, 2024

Thank you for your comment. I have fixed this.

@not522 not522 (Member) left a comment

LGTM!

@not522 not522 added this to the v4.1.0 milestone Oct 25, 2024
@not522 not522 merged commit 8158f46 into optuna:main Oct 25, 2024
6 checks passed
@nabenabe0928 nabenabe0928 added the enhancement label Oct 30, 2024
Labels: enhancement (change that does not break compatibility and does not affect public interfaces, but improves performance)