diff --git a/AutoCap/settings/pretraining.yaml b/AutoCap/settings/pretraining.yaml
index b20482b..e7dd4c4 100644
--- a/AutoCap/settings/pretraining.yaml
+++ b/AutoCap/settings/pretraining.yaml
@@ -149,7 +149,7 @@ audio_encoder_args:
 text_decoder_args:
     model_tag: "audio_qformer"
     name: "facebook/bart-base"
-    pretrained: false
+    pretrained: True
     freeze: True
     freeze_embed_layer: True
     bert_args:
diff --git a/AutoCap/src/models/pl_htsat_q_bart_captioning.py b/AutoCap/src/models/pl_htsat_q_bart_captioning.py
index aa29eed..90f964b 100644
--- a/AutoCap/src/models/pl_htsat_q_bart_captioning.py
+++ b/AutoCap/src/models/pl_htsat_q_bart_captioning.py
@@ -550,6 +550,7 @@ def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder
 
     def forward_encoder(self, audios):
         outputs = self.encoder(audios)
+        print("HTSAT embedding", outputs.last_hidden_state.shape)
         outputs = self.enc_to_dec_proj(outputs.last_hidden_state)
 
         # dropout
@@ -737,6 +738,7 @@ def generate(self,
                  ):
 
         encoder_outputs = self.forward_encoder(samples)
+        print("audio_embeds", encoder_outputs.shape)
         attn_mask = torch.ones(encoder_outputs.size()[:-1], dtype=torch.long, device=encoder_outputs.device)
 
         if self.use_audio_qformer:
@@ -1034,49 +1036,54 @@ def on_validation_epoch_end(self):
             if len(captions_pred) == 0 or len(captions_gt) == 0:
                 continue
 
-            metrics = evaluate_metrics(captions_pred, captions_gt, nb_reference_captions=5, bert_model=self.bert_model, exclude_metrics=self.exclude_metrics)
-
-
-            def get_score(metrics, key):
-                if key in metrics:
-                    return float(metrics[key]['score'])
-                else:
-                    return 0
-
-            spider = get_score(metrics, 'spider')
-            cider = get_score(metrics, 'cider')
-            spice = get_score(metrics, 'spice')
-            bleu_1 = get_score(metrics, 'bleu_1')
-            bleu_4 = get_score(metrics, 'bleu_4')
-            rouge_l = get_score(metrics, 'rouge_l')
-            meteor = get_score(metrics, 'meteor')
-
-            val_logger.info(f'Cider: {cider:7.4f}')
-            val_logger.info(
-                f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')
-            metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
-                            f"{split}/cider_beam_{beam_size}":cider,
-                            f"{split}/spice_beam_{beam_size}":spice,
-                            f"{split}/bleu_1_beam_{beam_size}":bleu_1,
-                            f"{split}/bleu_4_beam_{beam_size}":bleu_4,
-                            f"{split}/rouge_l_beam_{beam_size}":rouge_l,
-                            f"{split}/meteor_beam_{beam_size}":meteor }
-            if 'bert_score' in metrics:
-                bert_score = metrics.pop('bert_score')
-                metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score
-                val_logger.info(f"Bert score {bert_score}")
-
-            self.log_dict(metrics_log,
-                          prog_bar=True,
-                          logger=True,
-                          on_step=False,
-                          on_epoch=True,
-                          sync_dist=True)
-
-            for metric, values in metrics.items():
-                val_logger.info(f'beam search (size {beam_size}): {metric:<7s}: {values["score"]:7.4f}')
-
+            try:
+                metrics = evaluate_metrics(captions_pred, captions_gt, nb_reference_captions=5, bert_model=self.bert_model, exclude_metrics=self.exclude_metrics)
+
+
+                def get_score(metrics, key):
+                    if key in metrics:
+                        return float(metrics[key]['score'])
+                    else:
+                        return 0
+
+                spider = get_score(metrics, 'spider')
+                cider = get_score(metrics, 'cider')
+                spice = get_score(metrics, 'spice')
+                bleu_1 = get_score(metrics, 'bleu_1')
+                bleu_4 = get_score(metrics, 'bleu_4')
+                rouge_l = get_score(metrics, 'rouge_l')
+                meteor = get_score(metrics, 'meteor')
+
+                val_logger.info(f'Cider: {cider:7.4f}')
+                val_logger.info(
+                    f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')
+
+                metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
f"{split}/cider_beam_{beam_size}":cider, + f"{split}/spice_beam_{beam_size}":spice, + f"{split}/bleu_1_beam_{beam_size}":bleu_1, + f"{split}/bleu_4_beam_{beam_size}":bleu_4, + f"{split}/rouge_l_beam_{beam_size}":rouge_l, + f"{split}/meteor_beam_{beam_size}":meteor } + if 'bert_score' in metrics: + bert_score = metrics.pop('bert_score') + metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score + val_logger.info(f"Bert score {bert_score}") + + self.log_dict(metrics_log, + prog_bar=True, + logger=True, + on_step=False, + on_epoch=True, + sync_dist=True) + + for metric, values in metrics.items(): + val_logger.info(f'beam search (size {beam_size}): {metric:<7s}: {values["score"]:7.4f}') + except Exception as e: + print("Error while calculating the metrics.") + metrics_log = {} + self.log("time/val_epoch", time.time() - self.val_start_time, on_step=False, on_epoch=True, logger=True) return metrics_log diff --git a/AutoCap/src/modules/audio_encoder/audio_encoder.py b/AutoCap/src/modules/audio_encoder/audio_encoder.py index f1364fd..840f2a4 100644 --- a/AutoCap/src/modules/audio_encoder/audio_encoder.py +++ b/AutoCap/src/modules/audio_encoder/audio_encoder.py @@ -21,6 +21,7 @@ class AudioEncoderModel(PreTrainedModel): def __init__(self, config): super(AudioEncoderModel, self).__init__(config) + self.representation = config.representation if config.model_arch == "cnn": if config.model_name == 'ResNet38': self.audio_enc = ResNet38(config) @@ -79,7 +80,8 @@ def forward(self, input_ids, output_hidden_states=False, return_dict=True ): - audio_embeds = self.audio_enc(input_ids) + print("Using audio representation:", self.representation) + audio_embeds = self.audio_enc(input_ids, representation=self.representation) if not return_dict: return (audio_embeds, ) return BaseModelOutput(audio_embeds, None, None) diff --git a/AutoCap/src/modules/audio_encoder/audio_encoder_config.py b/AutoCap/src/modules/audio_encoder/audio_encoder_config.py index d811267..a3bf0f8 100644 --- a/AutoCap/src/modules/audio_encoder/audio_encoder_config.py +++ b/AutoCap/src/modules/audio_encoder/audio_encoder_config.py @@ -22,6 +22,7 @@ def __init__(self, freeze: bool = False, spec_augment: bool = True, audio_args: dict = None, + representation = 'fine_grained_embedding', **kwargs): super(AudioEncoderConfig, self).__init__(**kwargs) if model_arch not in ["cnn", "transformer"]: @@ -36,5 +37,7 @@ def __init__(self, self.hidden_size = 1024 if model_arch == "cnn" else 768 self.spec_augment = spec_augment self.audio_args = audio_args + self.representation = representation + print("config rep", self.representation ) self.num_labels = 0 \ No newline at end of file diff --git a/AutoCap/src/modules/audio_encoder/htsat.py b/AutoCap/src/modules/audio_encoder/htsat.py index f442580..c6634ba 100644 --- a/AutoCap/src/modules/audio_encoder/htsat.py +++ b/AutoCap/src/modules/audio_encoder/htsat.py @@ -836,13 +836,21 @@ def forward_features(self, x): x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T) x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1) + latent_output_64_reshaped = x.reshape(B, C, -1) + latent_output_64_reshaped = latent_output_64_reshaped.permute(0, 2, 1).contiguous() + fine_grained_latent_output = torch.mean(x, dim=2) + pooled_represetnation = fine_grained_latent_output fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]) - + + fine_grained_latent_output_256 = interpolate(pooled_represetnation.permute(0, 2, 1).contiguous(), + 2 * 
diff --git a/AutoCap/src/modules/audio_encoder/htsat.py b/AutoCap/src/modules/audio_encoder/htsat.py
index f442580..c6634ba 100644
--- a/AutoCap/src/modules/audio_encoder/htsat.py
+++ b/AutoCap/src/modules/audio_encoder/htsat.py
@@ -836,13 +836,21 @@ def forward_features(self, x):
             x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
             x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
+            latent_output_64_reshaped = x.reshape(B, C, -1)
+            latent_output_64_reshaped = latent_output_64_reshaped.permute(0, 2, 1).contiguous()
+
             fine_grained_latent_output = torch.mean(x, dim=2)
+            pooled_representation = fine_grained_latent_output
             fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0, 2, 1).contiguous(),
                                                      8 * self.patch_stride[1])
-
+
+            fine_grained_latent_output_256 = interpolate(pooled_representation.permute(0, 2, 1).contiguous(),
+                                                         2 * self.patch_stride[1])
+            fine_grained_latent_output_32 = pooled_representation.permute(0, 2, 1).contiguous()
             # get latent_output
             latent_output = self.avgpool(torch.flatten(x, 2))
             latent_output = torch.flatten(latent_output, 1)
+
 
             # display the attention map, if needed
             # if self.config.htsat_attn_heatmap:
@@ -886,6 +894,10 @@ def forward_features(self, x):
                 'framewise_output': fpx,  # already sigmoided
                 'clipwise_output': torch.sigmoid(x),
                 'fine_grained_embedding': fine_grained_latent_output,
+                "fine_grained_latent_output_32": fine_grained_latent_output_32,
+                "fine_grained_latent_output_256": fine_grained_latent_output_256,
+                "latent_output_64_reshaped": latent_output_64_reshaped,
+                "latent_output": latent_output,
                 'embedding': latent_output
             }
@@ -915,7 +927,6 @@ def reshape_wav2img(self, x):
             x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
             x = x.permute(0, 1, 3, 2).contiguous()
             x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
-            # print(x.shape)
             x = x.permute(0, 1, 3, 2, 4).contiguous()
             x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
             return x
@@ -936,7 +947,7 @@ def repeat_wat2img(self, x, cur_pos):
             x = x.repeat(repeats=(1, 1, 4, 1))
         return x
 
-    def forward(self, input: torch.Tensor, infer_mode=False):  # out_feat_keys: List[str] = None):
+    def forward(self, input: torch.Tensor, infer_mode=False, representation='fine_grained_embedding'):  # out_feat_keys: List[str] = None):
         # x = self.spectrogram_extractor(x)   # (batch_size, 1, time_steps, freq_bins)
         # x = self.logmel_extractor(x)        # (batch_size, 1, time_steps, mel_bins)
@@ -951,7 +962,7 @@ def forward(self, input: torch.Tensor, infer_mode=False):  # out_feat_keys: List
             x = self.reshape_wav2img(x)
             output_dict = self.forward_features(x)
         # x = self.head(x)
-        return output_dict["fine_grained_embedding"]
+        return output_dict[representation]
 
 
 if __name__ == '__main__':
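For quick reference, the keys that the new `representation` argument can select after the htsat.py change above are collected below. The one-line summaries paraphrase how each tensor is computed in `forward_features()`; no exact shapes are asserted, since they depend on the model configuration and input length.

```python
# Representation keys exposed by forward_features() after this change; the
# descriptions paraphrase the code in the hunk above.
HTSAT_REPRESENTATIONS = {
    "fine_grained_embedding":         "frequency-pooled tokens, time-interpolated by 8 * patch_stride[1] (previous default return)",
    "fine_grained_latent_output_256": "frequency-pooled tokens, time-interpolated by 2 * patch_stride[1]",
    "fine_grained_latent_output_32":  "frequency-pooled tokens without time interpolation",
    "latent_output_64_reshaped":      "full feature map reshaped to (B, freq*time, C)",
    "latent_output":                  "clip-level embedding after average pooling (same tensor as 'embedding')",
    "embedding":                      "clip-level embedding (pre-existing key)",
}
# 'framewise_output' and 'clipwise_output' also remain in the output dict, as before.
```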
diff --git a/dataset_preperation/README.md b/dataset_preperation/README.md
index fb53726..555e899 100644
--- a/dataset_preperation/README.md
+++ b/dataset_preperation/README.md
@@ -16,18 +16,28 @@ We introduce an efficient pipeline for collecting ambient audio. It starts by an
 For initializing your environment, please refer to the [general README](../README.md).
 
 ## Autocap Dataset Download
-
-
-**datasets will be coming later!**
+- We currently provide the following datasets:
+
+  * **AutoReCapXL:** more than **47M** audio-text pairs, filtered to a LAION CLAP similarity above 0.1.
+  * **AutoReCapXL-MQ:** more than **20.7M** audio-text pairs, filtered to a LAION CLAP similarity above 0.4.
+  * **AutoReCapXL-MQ-L:** more than **20.7M** audio-text pairs, filtered to a LAION CLAP similarity above 0.4 and audio clips longer than 5 seconds.
+  * **AutoReCapXL-HQ:** more than **10.7M** audio-text pairs, filtered to a LAION CLAP similarity above 0.5.
+
+The AutoReCap datasets are derived from YouTube videos. They contain mainly ambient audio clips, along with a few speech and music clips. Please refer to the paper for more details. The datasets can be further filtered by a CLAP similarity threshold and a minimum audio clip length, as described below.
 
 ```shell
 python download.py --save_dir <save_dir> --dataset_name <dataset_name>
+# Example
+python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL-HQ --audio_only
+
+# Example of filtering according to CLAP similarity and audio clip length
+python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL --clap_threshold 0.4 --min_audio_len 5 --audio_only
+
+# Example of downloading only a subset of the dataset
+python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL-HQ --start_idx 0 --end_idx 100000 --audio_only
 ```
-
-
+
 By default, the script will download videos along with their metadata.
diff --git a/dataset_preperation/download.py b/dataset_preperation/download.py
index 072582b..83fa88e 100644
--- a/dataset_preperation/download.py
+++ b/dataset_preperation/download.py
@@ -129,8 +129,11 @@ def update_interval_dict(dict_1, dict_2):
 
 def read_video_segments_info(local_input_video_segments,
                              start_idx=0,
-                             end_idx=int(1e9)):
+                             end_idx=int(1e9),
+                             min_audio_len=0.0,
+                             clap_threshold=0.0):
     all_video_segments = {}
+    total_number_of_clips = 0
     with open(local_input_video_segments, 'r') as f:
         last_idx = 0
         for idx, json_str in enumerate(tqdm(f, desc="parsing json input")):
@@ -141,8 +144,21 @@ def read_video_segments_info(local_input_video_segments,
                     json_str = json_str[:-1]
                 if json_str.endswith(','):
                     json_str = json_str[:-1]
-                json_object = json.loads(json_str)
-                update_interval_dict(all_video_segments, json_object)
+
+                data = json.loads(json_str)
+                video_ids = list(data.keys())
+                if len(video_ids) == 0:
+                    continue
+                video_id = video_ids[0]
+
+                intervals = data[video_id].get("intervals", [])
+                len_intervals_filtered = [clip for clip in intervals if float(clip['end']) - float(clip['start']) >= min_audio_len]
+                clap_len_intervals_filtered = [clip for clip in len_intervals_filtered if clip.get("CLAP_SIM", -9999) is not None and clip.get("CLAP_SIM", -9999) >= clap_threshold]
+                total_number_of_clips += len(clap_len_intervals_filtered)
+                video_data = {}
+                video_data[video_id] = {}
+                video_data[video_id]['intervals'] = clap_len_intervals_filtered
+                update_interval_dict(all_video_segments, video_data)
             except Exception as e:
                 print("[ERROR] Couldn't parse json string:", json_str)
                 continue
@@ -151,6 +167,7 @@
 
             if last_idx >= end_idx:
                 break
 
+    print(f"Found {total_number_of_clips} audio clips.")
     return all_video_segments
 
 def download_audioset_split(json_file,
@@ -163,14 +180,18 @@
                             end_idx=int(1e9),
                             num_processes=os.cpu_count(),
                             resume=True,
-                            files_per_folder=5000
+                            files_per_folder=5000,
+                            clap_threshold=0.4,
+                            min_audio_len=4,
                             ):
     os.makedirs(save_dir, exist_ok=True)
 
     all_video_segments = read_video_segments_info(json_file,
                                                   start_idx=start_idx,
-                                                  end_idx=end_idx)
+                                                  end_idx=end_idx,
+                                                  min_audio_len=min_audio_len,
+                                                  clap_threshold=clap_threshold)
 
     download_audio_split = partial(download_yt_video,
                                    save_dir=save_dir,
@@ -208,6 +229,18 @@
                         required=True,
                         help=f"Provided the dataset names. Available datasets are {dataset_urls.keys()}")
 
+    parser.add_argument("--clap_threshold",
+                        type=float,
+                        required=False,
+                        default=0.4,
+                        help="The CLAP similarity threshold used to filter the dataset, default: 0.4")
+
+    parser.add_argument("--min_audio_len",
+                        type=float,
+                        required=False,
+                        default=4,
+                        help="The minimum audio clip length (in seconds) used to filter the dataset, default: 4")
+
     parser.add_argument("--input_file",
                         type=str,
                         default=None,
@@ -269,5 +302,7 @@
                             proxy_port=args.proxy,
                             start_idx=args.start_idx,
                             end_idx=args.end_idx,
+                            clap_threshold=args.clap_threshold,
+                            min_audio_len=args.min_audio_len,
                             resume=not args.redownload,
                             files_per_folder=args.files_per_folder)
diff --git a/dataset_preperation/download_manager.py b/dataset_preperation/download_manager.py
index 815f764..76ad085 100644
--- a/dataset_preperation/download_manager.py
+++ b/dataset_preperation/download_manager.py
@@ -2,8 +2,11 @@
 import wget
 
 save_dir = 'data/json_files'
-dataset_urls = {}
-
+dataset_urls = {"AutoReCapXL": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_01.json',
+                "AutoReCapXL-MQ": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_04.json',
+                "AutoReCapXL-MQ-L": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_04_longer_than_5s.json',
+                "AutoReCapXL-HQ": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_05.json'}
+
 def get_dataset_json_file(dataset_name, dataset_json_file_path=None, download=True):
     if dataset_json_file_path is None:
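For readers skimming the diff, here is a minimal, self-contained sketch of the per-clip filter that `read_video_segments_info` now applies before any download starts. The record layout (video id mapped to `intervals` with `start`, `end`, and `CLAP_SIM`) is taken from the code above; the concrete values are invented for illustration.

```python
# Minimal sketch of the filtering added to read_video_segments_info();
# the sample record below is made up, only the field names follow the code.
record = {
    "some_video_id": {
        "intervals": [
            {"start": 12.0, "end": 19.5, "CLAP_SIM": 0.62},   # kept
            {"start": 40.0, "end": 42.0, "CLAP_SIM": 0.55},   # dropped: shorter than min_audio_len
            {"start": 80.0, "end": 95.0, "CLAP_SIM": 0.21},   # dropped: below clap_threshold
            {"start": 100.0, "end": 110.0},                   # dropped: no CLAP_SIM score
        ]
    }
}

def keep_clip(clip, min_audio_len=4.0, clap_threshold=0.4):
    long_enough = float(clip["end"]) - float(clip["start"]) >= min_audio_len
    clap_sim = clip.get("CLAP_SIM", -9999)
    similar_enough = clap_sim is not None and clap_sim >= clap_threshold
    return long_enough and similar_enough

video_id = next(iter(record))
kept = [c for c in record[video_id]["intervals"] if keep_clip(c)]
print(f"kept {len(kept)} of {len(record[video_id]['intervals'])} clips")  # kept 1 of 4 clips
```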