Save Predictions + Add prediction logging #335

Merged · 8 commits · Nov 22, 2021
31 changes: 27 additions & 4 deletions heareval/predictions/runner.py
@@ -7,6 +7,7 @@
import json
import sys
import time
import logging
from pathlib import Path
from typing import Any, List

@@ -18,6 +19,24 @@
from heareval.predictions.task_predictions import task_predictions


def get_logger(task_name: str, log_path: Path) -> logging.Logger:
"""Returns a task level logger"""
logger = logging.getLogger(task_name)
logger.setLevel(logging.INFO)
fh = logging.FileHandler(log_path)
fh.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# create formatter and add it to the handlers
formatter = logging.Formatter("TASK - %(name)s - prediction %(message)s")
ch.setFormatter(formatter)
fh.setFormatter(formatter)
# add the handlers to logger
logger.addHandler(ch)
logger.addHandler(fh)
return logger
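
(Illustration, not part of the diff: a minimal sketch of how this helper might be exercised; the task name and log location below are placeholders.)

from pathlib import Path

# Placeholder task name and log path, for illustration only.
logger = get_logger(task_name="speech_commands-v0.0.2", log_path=Path("prediction.log"))
logger.info("Computing predictions for speech_commands-v0.0.2")
# With the formatter above, stdout and prediction.log both receive a line like:
# TASK - speech_commands-v0.0.2 - prediction Computing predictions for speech_commands-v0.0.2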


@click.command()
@click.argument(
"task_dirs",
@@ -79,10 +98,13 @@ def runner(
# We already did this
continue

print(f"Computing predictions for {task_path.name}")

# Get embedding sizes for all splits/folds
metadata = json.load(task_path.joinpath("task_metadata.json").open())

log_path = task_path.joinpath("prediction.log")
logger = get_logger(task_name=metadata["task_name"], log_path=log_path)

logger.info(f"Computing predictions for {task_path.name}")
embedding_sizes = []
for split in metadata["splits"]:
split_path = task_path.joinpath(f"{split}.embedding-dimensions.json")
@@ -104,11 +126,12 @@
in_memory=in_memory,
deterministic=deterministic,
grid=grid,
logger=logger,
)
sys.stdout.flush()
gpu_max_mem_used = gpu_max_mem.measure()
print(
f"DONE. took {time.time() - start} seconds to complete task_predictions"
logger.info(
f"DONE took {time.time() - start} seconds to complete task_predictions"
f"(embedding_path={task_path}, embedding_size={embedding_size}, "
f"grid_points={grid_points}, gpus={gpus}, "
f"gpu_max_mem_used={gpu_max_mem_used}, "
60 changes: 40 additions & 20 deletions heareval/predictions/task_predictions.py
@@ -14,6 +14,7 @@

import copy
import json
import logging
import math
import multiprocessing
import pickle
@@ -333,7 +334,11 @@ def _score_epoch_end(self, name: str, outputs: List[Dict[str, List[Any]]]):

if name == "test":
# Cache all predictions for later serialization
self.test_predicted_labels = prediction
self.test_predictions = {
"target": target.detach().cpu(),
"prediction": prediction.detach().cpu(),
"prediction_logit": prediction_logit.detach().cpu(),
}

for score in self.scores:
end_scores[f"{name}_{score}"] = score(
@@ -449,10 +454,14 @@ def _score_epoch_end(self, name: str, outputs: List[Dict[str, List[Any]]]):
end_scores[f"{name}_loss"] = self.predictor.logit_loss(prediction_logit, target)

if name == "test":
# print("test epoch", self.current_epoch)
# Cache all predictions for later serialization
self.test_predicted_labels = prediction
self.test_predicted_events = predicted_events
self.test_predictions = {
"target": target.detach().cpu(),
"prediction": prediction.detach().cpu(),
"prediction_logit": prediction_logit.detach().cpu(),
"target_events": self.target_events[name],
"predicted_events": predicted_events,
}

for score in self.scores:
end_scores[f"{name}_{score}"] = score(
@@ -792,6 +801,7 @@ def dataloader_from_split_name(
class GridPointResult:
def __init__(
self,
predictor,
model_path: str,
epoch: int,
time_in_min: float,
@@ -802,6 +812,7 @@ def __init__(
score_mode: str,
conf: Dict,
):
self.predictor = predictor
self.model_path = model_path
self.epoch = epoch
self.time_in_min = time_in_min
@@ -977,6 +988,7 @@ def _combine_target_events(split_names: List[str]):
logger.finalize("success")
logger.save()
return GridPointResult(
predictor=predictor,
model_path=checkpoint_callback.best_model_path,
epoch=epoch,
time_in_min=time_in_min,
@@ -1068,7 +1080,6 @@ def data_splits_from_folds(folds: List[str]) -> List[Dict[str, List[str]]]:
{
"train": train_folds,
"valid": [valid_fold],
"train+valid": train_folds + [valid_fold],
"test": [test_fold],
}
)
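
(Illustration, not part of the diff: with the "train+valid" key removed, each entry returned by data_splits_from_folds now carries only train/valid/test fold lists. Assuming three hypothetical folds, and ignoring the exact rotation logic, which lies outside the shown hunk, a single entry would look roughly like this.)

# Hypothetical shape of one data split after this change; fold names
# and the train/valid/test assignment are placeholders.
example_split = {
    "train": ["fold00"],
    "valid": ["fold01"],
    "test": ["fold02"],
}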
@@ -1181,10 +1192,14 @@ def sort_grid_points(
return grid_point_results


def print_scores(grid_point_results: List[GridPointResult], embedding_path: Path):
def print_scores(
grid_point_results: List[GridPointResult],
embedding_path: Path,
logger: logging.Logger,
):
grid_point_results = sort_grid_points(grid_point_results)
for g in grid_point_results:
print(g, str(embedding_path))
logger.info(f"Grid Point Summary: {g}")


def task_predictions(
@@ -1195,6 +1210,7 @@ def task_predictions(
in_memory: bool,
deterministic: bool,
grid: str,
logger: logging.Logger,
):
# By setting workers=True in seed_everything(), Lightning derives
# unique seeds across all dataloader workers and processes
@@ -1252,7 +1268,7 @@

grid_point_results = []
for conf in tqdm(confs[:grid_points], desc="grid"):
print("trying grid point", conf)
logger.info(f"Trying Grid Point: {conf}")
grid_point_result = task_predictions_train(
embedding_path=embedding_path,
embedding_size=embedding_size,
@@ -1267,24 +1283,24 @@
deterministic=deterministic,
)
grid_point_results.append(grid_point_result)
print_scores(grid_point_results, embedding_path)
print_scores(grid_point_results, embedding_path, logger)

# Use the best hyperparameters to train models for remaining folds,
# then compute test scores using the resulting models
grid_point_results = sort_grid_points(grid_point_results)
best_grid_point = grid_point_results[0]
print(
"Best validation score",
best_grid_point.validation_score,
best_grid_point.hparams,
embedding_path,
logger.info(
"Best Grid Point Validation Score: "
f"{best_grid_point.validation_score} "
"Grid Point HyperParams: "
f"{best_grid_point.hparams} "
)

# Train predictors for the remaining splits using the hyperparameters selected
# from the grid search.
split_grid_points = [best_grid_point]
for split in data_splits[1:]:
print(f"Training split: {split}")
logger.info(f"Training Split: {split}")
grid_point_result = task_predictions_train(
embedding_path=embedding_path,
embedding_size=embedding_size,
@@ -1299,10 +1315,9 @@
deterministic=deterministic,
)
split_grid_points.append(grid_point_result)
print(
f"Split {split} validation score: ",
grid_point_result.validation_score,
embedding_path,
logger.info(
f"Validation Score for the Training Split: "
f"{grid_point_result.validation_score}"
)

# Now test each of the trained models
@@ -1319,6 +1334,11 @@
in_memory=in_memory,
)

# Cache predictions for detailed analysis
prediction_file = embedding_path.joinpath(f"{test_fold_str}.predictions.pkl")
with open(prediction_file, "wb") as fp:
pickle.dump(split_grid_points[i].predictor.test_predictions, fp)

# Add model training values relevant to this split model
test_results[test_fold_str].update(
{
@@ -1349,7 +1369,7 @@ def task_predictions(
open(embedding_path.joinpath("test.predicted-scores.json"), "wt").write(
json.dumps(test_results, indent=4)
)
print("TEST RESULTS", json.dumps(test_results))
logger.info(f"Final Test Results: {json.dumps(test_results)}")

# We no longer have best_predictor, the predictor is
# loaded by trainer.test and then disappears
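
(Illustration, not part of the diff: with this change, each test fold's cached predictions can be reloaded for offline analysis. A minimal sketch, assuming a hypothetical embedding directory and fold name; the dictionary keys follow the test_predictions dicts added above.)

import pickle
from pathlib import Path

# Placeholder paths; point these at your own embeddings output directory.
embedding_path = Path("embeddings/my-module/speech_commands-v0.0.2")
with open(embedding_path.joinpath("fold00.predictions.pkl"), "rb") as fp:
    preds = pickle.load(fp)

# Keys per the diff: "target", "prediction", "prediction_logit", plus
# "target_events" and "predicted_events" for event-based tasks.
print(preds["prediction"].shape)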