forked from Project-MONAI/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhance Deepedit bundle to enable finetune and early stopping (Projec…
…t-MONAI#504) ### Description This PR is used to enable early stopping and finetune for the spleen deepedit annotation bundle. ### Status **Ready** ### Please ensure all the checkboxes: <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: Yiheng Wang <[email protected]>
- Loading branch information
1 parent
df9bab7
commit 2a90cff
Showing
7 changed files
with
48 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .early_stop_score_function import score_function |
15 changes: 15 additions & 0 deletions
15
models/spleen_deepedit_annotation/scripts/early_stop_score_function.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import os | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
||
|
||
def score_function(engine): | ||
val_metric = engine.state.metrics["val_mean_dice"] | ||
if dist.is_initialized(): | ||
device = torch.device("cuda:" + os.environ["LOCAL_RANK"]) | ||
val_metric = torch.tensor([val_metric]).to(device) | ||
dist.all_reduce(val_metric, op=dist.ReduceOp.SUM) | ||
val_metric /= dist.get_world_size() | ||
return val_metric.item() | ||
return val_metric |