-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
105 lines (86 loc) · 3.79 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from comet_ml.exceptions import InterruptedExperiment
import asyncio
import torch
import hydra
import json
import shutil
from models.cnn.unet.model import UnetModel
from datasets.wildfire_data_module import WildfireDataModule
from pathlib import Path
from loguru import logger
from omegaconf import DictConfig
from predictors.map_predictor import MapPredictor
from logging_utils.logging import setup_logger
from boundaries.canada_boundary import CanadaBoundary
from data_sources.canada_boundary_data_source import CanadaBoundaryDataSource
def _load_train_stats(split_info_file_path: Path):
    """Return the ``"train_stats"`` entry of the JSON split-info file.

    The value is handed verbatim to ``WildfireDataModule``; its structure is
    defined by whatever produced the split-info file (not visible here).
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale encoding.
    with open(split_info_file_path, "r", encoding="utf-8") as split_info_file:
        split_info = json.load(split_info_file)
    return split_info["train_stats"]


def _build_model(cfg: DictConfig, device: torch.device) -> UnetModel:
    """Create the U-Net described by ``cfg`` and load its trained weights.

    Args:
        cfg: Hydra config; reads ``cfg.model.*`` and
            ``cfg.data.input_data_indexes_to_remove``.
        device: Device the checkpoint tensors are mapped onto while loading.

    Returns:
        The model with its weights loaded (``strict=True``), in eval mode.
    """
    model = UnetModel(
        # Presumably the data module drops these input channels, so the model
        # is built with correspondingly fewer input channels.
        in_channels=cfg["model"]["number_of_input_channels"]
        - len(cfg["data"]["input_data_indexes_to_remove"]),
        nb_classes=cfg["model"]["number_of_classes"],
        activation_fn_name=cfg["model"]["activation_fn_name"],
        num_encoder_decoder_blocks=cfg["model"]["num_encoder_decoder_blocks"],
        use_batchnorm=cfg["model"]["use_batchnorm"],
    )
    logger.info("Loading trained model...")
    # map_location lets a checkpoint saved on a GPU be loaded on the CPU
    # fallback device; without it torch.load fails when CUDA is unavailable.
    state_dict = torch.load(
        Path(cfg["model"]["trained_model_path"]),
        map_location=device,
        weights_only=True,
    )
    model.load_state_dict(state_dict, strict=True)
    # Inference behaviour for batch-norm/dropout layers. eval() is idempotent,
    # so this is harmless if MapPredictor also switches the mode.
    model.eval()
    return model


def _move_into_folder(source_path, destination_folder: Path) -> Path:
    """Move ``source_path`` into ``destination_folder`` and return its new path.

    An existing entry with the same name is replaced. ``shutil.move`` into a
    directory raises ``shutil.Error`` when such an entry already exists, which
    would make a re-run with the same run name fail after prediction finished.
    """
    source_path = Path(source_path)
    destination_path = destination_folder / source_path.name
    if destination_path.exists():
        logger.warning(f"Overwriting existing prediction: {destination_path}")
        if destination_path.is_dir():
            shutil.rmtree(destination_path)
        else:
            destination_path.unlink()
    shutil.move(str(source_path), str(destination_path))
    return destination_path


@hydra.main(version_base=None, config_path="config", config_name="predict")
def main(cfg: DictConfig):
    """Predict a map with a trained U-Net, configured by Hydra (``config/predict``).

    Steps:
        1. Set up logging and create ``<output_path>/<run name>/{tmp,maps}``.
        2. Load ``train_stats`` from the split-info JSON and build the
           ``WildfireDataModule`` for the ``"predict"`` stage.
        3. Build the U-Net and load its trained weights.
        4. Load the Canada boundary for ``cfg.data.provinces`` (its data source
           writes into ``tmp``).
        5. Run ``MapPredictor.predict`` under ``asyncio``, move the resulting
           map into ``maps`` and delete ``tmp``.

    If ``InterruptedExperiment`` is raised during prediction, the interruption
    is logged and ``tmp`` is kept.

    Args:
        cfg: Config injected by ``hydra.main``.
    """
    run_name = cfg["run"]["name"]
    debug = cfg["debug"]
    setup_logger(logger, run_name, debug)
    logger.info(f"Run name: {run_name}")
    logger.info(f"Debug : {debug}")

    # Output layout: <output_path>/<run name>/{tmp, maps}
    run_output_path = Path(cfg["output_path"]) / Path(run_name)
    predict_tmp_output_path = run_output_path / "tmp"
    predict_final_output_path = run_output_path / "maps"
    for folder in (run_output_path, predict_tmp_output_path, predict_final_output_path):
        folder.mkdir(parents=True, exist_ok=True)

    logger.info("Loading split info...")
    # Statistics of the training split; presumably used to normalize inputs
    # the same way as during training.
    train_stats = _load_train_stats(Path(cfg["data"]["split_info_file_path"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Device: {device}")

    logger.info("Creating data module...")
    data_module = WildfireDataModule(
        input_data_indexes_to_remove=cfg["data"]["input_data_indexes_to_remove"],
        eval_batch_size=cfg["predict"]["batch_size"],
        input_data_no_data_value=cfg["data"]["input_data_no_data_value"],
        input_data_new_no_data_value=cfg["data"]["input_data_new_no_data_value"],
        predict_folder_path=Path(cfg["data"]["input_data_folder_path"]),
        train_stats=train_stats,
        data_loading_num_workers=cfg["data"]["data_loading_num_workers"],
        device=device,
    )
    data_module.setup(stage="predict")
    predict_dl = data_module.predict_dataloader()

    model = _build_model(cfg, device)

    canada_boundary = CanadaBoundary(
        data_source=CanadaBoundaryDataSource(output_path=predict_tmp_output_path),
        target_epsg=cfg["data"]["target_srid"],
    )
    canada_boundary.load(list(cfg["data"]["provinces"]))

    predictor = MapPredictor(
        model=model,
        device=device,
        output_folder_path=predict_tmp_output_path,
        canada_boundary=canada_boundary,
        convert_model_output_to_probabilities=cfg["predict"][
            "convert_model_output_to_probabilities"
        ],
    )

    try:
        final_map_output_path = asyncio.run(predictor.predict(predict_dl))
    except InterruptedExperiment as exc:
        # Embed the message in the text: loguru only substitutes extra
        # positional arguments into "{}" placeholders, so passing str(exc) as
        # a second argument would silently drop it.
        logger.info(f"status: {exc}")
        logger.info("Experiment interrupted!")
        return

    saved_map_path = _move_into_folder(final_map_output_path, predict_final_output_path)
    logger.success(f"Predictions saved at: {saved_map_path}")
    # Intermediate files are only removed after a successful run.
    shutil.rmtree(predict_tmp_output_path)


if __name__ == "__main__":
    # hydra.main parses the command line and supplies ``cfg``.
    main()