
Merge pull request #331 from neuralaudio/add-s-metric
Add event based metrics
turian authored Nov 20, 2021
2 parents 04fa248 + 3622310 commit f14889d
Showing 2 changed files with 49 additions and 6 deletions.
37 changes: 34 additions & 3 deletions heareval/predictions/task_predictions.py
@@ -362,6 +362,7 @@ def __init__(
scores: List[ScoreFunction],
validation_target_events: Dict[str, List[Dict[str, Any]]],
test_target_events: Dict[str, List[Dict[str, Any]]],
postprocessing_grid: Dict[str, List[float]],
conf: Dict,
):
super().__init__(
@@ -378,6 +379,7 @@ def __init__(
}
# For each epoch, what postprocessing parameters were best
self.epoch_best_postprocessing: Dict[int, Tuple[Tuple[str, Any], ...]] = {}
self.postprocessing_grid = postprocessing_grid

def _score_epoch_end(self, name: str, outputs: List[Dict[str, List[Any]]]):
flat_outputs = self._flatten_batched_outputs(
@@ -410,7 +412,12 @@ def _score_epoch_end(self, name: str, outputs: List[Dict[str, List[Any]]]):
# print("\n\n\n", epoch)

predicted_events_by_postprocessing = get_events_for_all_files(
- prediction, filename, timestamp, self.idx_to_label, postprocessing_cached
+ prediction,
+ filename,
+ timestamp,
+ self.idx_to_label,
+ self.postprocessing_grid,
+ postprocessing_cached,
)

score_and_postprocessing = []
@@ -617,6 +624,7 @@ def get_events_for_all_files(
filenames: List[str],
timestamps: torch.Tensor,
idx_to_label: Dict[int, str],
postprocessing_grid: Dict[str, List[float]],
postprocessing: Optional[Tuple[Tuple[str, Any], ...]] = None,
) -> Dict[Tuple[Tuple[str, Any], ...], Dict[str, List[Dict[str, Union[str, float]]]]]:
"""
@@ -630,7 +638,8 @@
If no postprocessing is specified (during training), we try a
variety of ways of postprocessing the predictions into events,
- including median filtering and minimum event length.
+ from the postprocessing_grid including median filtering and
+ minimum event length.
If postprocessing is specified (during test, chosen at the best
validation epoch), we use this postprocessing.
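
For context, the postprocessing argument above is a single point from the grid, frozen into a tuple of (name, value) pairs so it can be hashed and compared across epochs; the grid itself maps each postprocessing knob to the candidate values to sweep. A minimal sketch (the key names and values are illustrative, not necessarily those in EVENT_POSTPROCESSING_GRID):

from typing import Any, Dict, List, Tuple

# Illustrative grid: each knob maps to the values tried during training.
postprocessing_grid: Dict[str, List[float]] = {
    "median_filter_ms": [150, 250],
    "min_duration": [125, 250],
}

# One chosen configuration, in the hashable tuple form used at test time.
chosen: Tuple[Tuple[str, Any], ...] = (("median_filter_ms", 250), ("min_duration", 125))
assert dict(chosen) == {"median_filter_ms": 250, "min_duration": 125}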
@@ -679,7 +688,7 @@ def get_events_for_all_files(
timestamp_predictions, idx_to_label, **dict(postprocess)
)
else:
- postprocessing_confs = list(ParameterGrid(EVENT_POSTPROCESSING_GRID))
+ postprocessing_confs = list(ParameterGrid(postprocessing_grid))
for postprocess_dict in tqdm(postprocessing_confs):
postprocess = tuple(postprocess_dict.items())
event_dict[postprocess] = {}
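
The sweep itself leans on scikit-learn's ParameterGrid, which expands a dict of candidate values into every combination; each combination is then converted to a tuple so it can key event_dict. A small sketch with made-up values:

from sklearn.model_selection import ParameterGrid

grid = {"median_filter_ms": [150, 250], "min_duration": [125]}  # illustrative values
for postprocess_dict in ParameterGrid(grid):
    # Freeze each configuration into a hashable key, as in the loop above.
    postprocess = tuple(postprocess_dict.items())
    print(postprocess)
# (('median_filter_ms', 150), ('min_duration', 125))
# (('median_filter_ms', 250), ('min_duration', 125))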
@@ -865,6 +874,18 @@ def _combine_target_events(split_names: List[str]):
validation_target_events: Dict = _combine_target_events(data_splits["valid"])
test_target_events: Dict = _combine_target_events(data_splits["test"])

# The postprocessing search space used to find the best
# task-specific postprocessing. It can be overridden per task
# via evaluation_params.event_postprocessing_grid in the task
# metadata; otherwise the default EVENT_POSTPROCESSING_GRID is used.
if "event_postprocessing_grid" in metadata.get("evaluation_params", {}):
postprocessing_grid = metadata["evaluation_params"][
"event_postprocessing_grid"
]
else:
postprocessing_grid = EVENT_POSTPROCESSING_GRID

predictor = EventPredictionModel(
nfeatures=embedding_size,
label_to_idx=label_to_idx,
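
The block above lets a task override the default search space through its metadata. A hypothetical metadata snippet and an equivalent lookup (only the key names evaluation_params and event_postprocessing_grid come from the code; the grids themselves are made up):

# Stand-in for the module-level default grid.
EVENT_POSTPROCESSING_GRID = {"median_filter_ms": [250], "min_duration": [125, 250]}

# Hypothetical task metadata with a task-specific override.
metadata = {
    "task_name": "some_event_task",
    "embedding_type": "event",
    "evaluation_params": {
        "event_postprocessing_grid": {"median_filter_ms": [250], "min_duration": [60, 125]}
    },
}

# Equivalent to the if/else in the diff above.
postprocessing_grid = metadata.get("evaluation_params", {}).get(
    "event_postprocessing_grid", EVENT_POSTPROCESSING_GRID
)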
@@ -873,6 +894,7 @@ def _combine_target_events(split_names: List[str]):
scores=scores,
validation_target_events=validation_target_events,
test_target_events=test_target_events,
postprocessing_grid=postprocessing_grid,
conf=conf,
)
elif metadata["embedding_type"] == "scene":
@@ -1212,9 +1234,18 @@ def task_predictions(
)

# Update with task specific grid parameters
# From the global TASK_SPECIFIC_PARAM_GRID
if metadata["task_name"] in TASK_SPECIFIC_PARAM_GRID:
final_grid.update(TASK_SPECIFIC_PARAM_GRID[metadata["task_name"]])

# From the task-specific parameter grid in the task metadata.
# This option lets secret tasks supply their own parameter grid
# without being listed in the global TASK_SPECIFIC_PARAM_GRID.
# Ideally only one of the two options should be present.
if "task_specific_param_grid" in metadata.get("evaluation_params", {}):
final_grid.update(metadata["evaluation_params"]["task_specific_param_grid"])

# Model selection
confs = list(ParameterGrid(final_grid))
random.shuffle(confs)
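
Because both overrides go through dict.update on final_grid, whichever grid is applied last wins for any shared key, with the metadata-supplied task_specific_param_grid applied after the global one. A toy illustration with made-up hyperparameter names:

# Made-up defaults and overrides, for illustration only.
final_grid = {"lr": [1e-3, 3e-4], "hidden_dim": [1024]}
TASK_SPECIFIC_PARAM_GRID = {"demo_task": {"lr": [1e-4]}}
metadata = {
    "task_name": "demo_task",
    "evaluation_params": {"task_specific_param_grid": {"hidden_dim": [512]}},
}

# Same update order as in the diff above.
if metadata["task_name"] in TASK_SPECIFIC_PARAM_GRID:
    final_grid.update(TASK_SPECIFIC_PARAM_GRID[metadata["task_name"]])
if "task_specific_param_grid" in metadata.get("evaluation_params", {}):
    final_grid.update(metadata["evaluation_params"]["task_specific_param_grid"])

print(final_grid)  # {'lr': [0.0001], 'hidden_dim': [512]}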
18 changes: 15 additions & 3 deletions heareval/score.py
@@ -310,11 +310,23 @@ def __call__(self, predictions: np.ndarray, targets: np.ndarray, **kwargs) -> float:
EventBasedScore,
name="event_onset_200ms_fms",
score="f_measure",
params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.2},
),
"event_onset_50ms_fms": partial(
EventBasedScore,
name="event_onset_50ms_fms",
score="f_measure",
params={"evaluate_onset": True, "evaluate_offset": False, "t_collar": 0.05},
),
"event_onset_offset_50ms_20perc_fms": partial(
EventBasedScore,
name="event_onset_offset_50ms_20perc_fms",
score="f_measure",
params={
"evaluate_onset": True,
"evaluate_offset": False,
"t_collar": 0.2,
"percentage_of_length": 0.5,
"evaluate_offset": True,
"t_collar": 0.05,
"percentage_of_length": 0.2,
},
),
"segment_1s_er": partial(
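
For readers unfamiliar with these parameters: t_collar is the tolerance, in seconds, around event onsets (and offsets when evaluate_offset is on), and percentage_of_length relaxes the offset tolerance to a fraction of the reference event's length. The names match the event-based F-measure in sed_eval, which the EventBasedScore wrapper appears to build on; the sketch below targets sed_eval's public API and is illustrative, not a quote of heareval's implementation:

import sed_eval

# Roughly what "event_onset_offset_50ms_20perc_fms" configures.
metrics = sed_eval.sound_event.EventBasedMetrics(
    event_label_list=["dog_bark", "siren"],  # illustrative labels
    evaluate_onset=True,
    evaluate_offset=True,
    t_collar=0.05,                # 50 ms collar around onsets/offsets
    percentage_of_length=0.2,     # offset tolerance may grow to 20% of the event length
)
# metrics.evaluate(reference_event_list=..., estimated_event_list=...)
# f = metrics.results_overall_metrics()["f_measure"]["f_measure"]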
