Commit: add datasets

MoayedHajiAli committed Dec 23, 2024
1 parent 57b9c8d · commit 51ab220
Showing 8 changed files with 133 additions and 62 deletions.
2 changes: 1 addition & 1 deletion AutoCap/settings/pretraining.yaml
@@ -149,7 +149,7 @@ audio_encoder_args:
text_decoder_args:
model_tag: "audio_qformer"
name: "facebook/bart-base"
pretrained: false
pretrained: True
freeze: True
freeze_embed_layer: True
bert_args:
91 changes: 49 additions & 42 deletions AutoCap/src/models/pl_htsat_q_bart_captioning.py
@@ -550,6 +550,7 @@ def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder

def forward_encoder(self, audios):
outputs = self.encoder(audios)
print("HTSAT embedding", outputs.last_hidden_state.shape)
outputs = self.enc_to_dec_proj(outputs.last_hidden_state)

# dropout
Expand Down Expand Up @@ -737,6 +738,7 @@ def generate(self,
):

encoder_outputs = self.forward_encoder(samples)
print("audio_embeds", encoder_outputs.shape)
attn_mask = torch.ones(encoder_outputs.size()[:-1], dtype=torch.long, device=encoder_outputs.device)

if self.use_audio_qformer:
Expand Down Expand Up @@ -1034,49 +1036,54 @@ def on_validation_epoch_end(self):

if len(captions_pred) == 0 or len(captions_gt) == 0:
continue
metrics = evaluate_metrics(captions_pred, captions_gt, nb_reference_captions=5, bert_model=self.bert_model, exclude_metrics=self.exclude_metrics)


def get_score(metrics, key):
if key in metrics:
return float(metrics[key]['score'])
else:
return 0

spider = get_score(metrics, 'spider')
cider = get_score(metrics, 'cider')
spice = get_score(metrics, 'spice')
bleu_1 = get_score(metrics, 'bleu_1')
bleu_4 = get_score(metrics, 'bleu_4')
rouge_l = get_score(metrics, 'rouge_l')
meteor = get_score(metrics, 'meteor')

val_logger.info(f'Cider: {cider:7.4f}')
val_logger.info(
f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')

metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
f"{split}/cider_beam_{beam_size}":cider,
f"{split}/spice_beam_{beam_size}":spice,
f"{split}/bleu_1_beam_{beam_size}":bleu_1,
f"{split}/bleu_4_beam_{beam_size}":bleu_4,
f"{split}/rouge_l_beam_{beam_size}":rouge_l,
f"{split}/meteor_beam_{beam_size}":meteor }
if 'bert_score' in metrics:
bert_score = metrics.pop('bert_score')
metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score
val_logger.info(f"Bert score {bert_score}")

self.log_dict(metrics_log,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)

for metric, values in metrics.items():
val_logger.info(f'beam search (size {beam_size}): {metric:<7s}: {values["score"]:7.4f}')

try:
metrics = evaluate_metrics(captions_pred, captions_gt, nb_reference_captions=5, bert_model=self.bert_model, exclude_metrics=self.exclude_metrics)


def get_score(metrics, key):
if key in metrics:
return float(metrics[key]['score'])
else:
return 0

spider = get_score(metrics, 'spider')
cider = get_score(metrics, 'cider')
spice = get_score(metrics, 'spice')
bleu_1 = get_score(metrics, 'bleu_1')
bleu_4 = get_score(metrics, 'bleu_4')
rouge_l = get_score(metrics, 'rouge_l')
meteor = get_score(metrics, 'meteor')

val_logger.info(f'Cider: {cider:7.4f}')
val_logger.info(
f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')

metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
f"{split}/cider_beam_{beam_size}":cider,
f"{split}/spice_beam_{beam_size}":spice,
f"{split}/bleu_1_beam_{beam_size}":bleu_1,
f"{split}/bleu_4_beam_{beam_size}":bleu_4,
f"{split}/rouge_l_beam_{beam_size}":rouge_l,
f"{split}/meteor_beam_{beam_size}":meteor }
if 'bert_score' in metrics:
bert_score = metrics.pop('bert_score')
metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score
val_logger.info(f"Bert score {bert_score}")

self.log_dict(metrics_log,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)

for metric, values in metrics.items():
val_logger.info(f'beam search (size {beam_size}): {metric:<7s}: {values["score"]:7.4f}')
except Exception as e:
print("Error while calculating the metrics.")
metrics_log = {}

self.log("time/val_epoch", time.time() - self.val_start_time, on_step=False, on_epoch=True, logger=True)
return metrics_log

4 changes: 3 additions & 1 deletion AutoCap/src/modules/audio_encoder/audio_encoder.py
@@ -21,6 +21,7 @@ class AudioEncoderModel(PreTrainedModel):
def __init__(self, config):
super(AudioEncoderModel, self).__init__(config)

self.representation = config.representation
if config.model_arch == "cnn":
if config.model_name == 'ResNet38':
self.audio_enc = ResNet38(config)
@@ -79,7 +80,8 @@ def forward(self, input_ids,
output_hidden_states=False,
return_dict=True
):
audio_embeds = self.audio_enc(input_ids)
print("Using audio representation:", self.representation)
audio_embeds = self.audio_enc(input_ids, representation=self.representation)
if not return_dict:
return (audio_embeds, )
return BaseModelOutput(audio_embeds, None, None)
3 changes: 3 additions & 0 deletions AutoCap/src/modules/audio_encoder/audio_encoder_config.py
@@ -22,6 +22,7 @@ def __init__(self,
freeze: bool = False,
spec_augment: bool = True,
audio_args: dict = None,
representation = 'fine_grained_embedding',
**kwargs):
super(AudioEncoderConfig, self).__init__(**kwargs)
if model_arch not in ["cnn", "transformer"]:
@@ -36,5 +37,7 @@ def __init__(self,
self.hidden_size = 1024 if model_arch == "cnn" else 768
self.spec_augment = spec_augment
self.audio_args = audio_args
self.representation = representation
print("config rep", self.representation )
self.num_labels = 0

19 changes: 15 additions & 4 deletions AutoCap/src/modules/audio_encoder/htsat.py
@@ -836,13 +836,21 @@ def forward_features(self, x):
x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)

latent_output_64_reshaped = x.reshape(B, C, -1)
latent_output_64_reshaped = latent_output_64_reshaped.permute(0, 2, 1).contiguous()

fine_grained_latent_output = torch.mean(x, dim=2)
pooled_representation = fine_grained_latent_output
fine_grained_latent_output = interpolate(fine_grained_latent_output.permute(0, 2, 1).contiguous(),
8 * self.patch_stride[1])


fine_grained_latent_output_256 = interpolate(pooled_representation.permute(0, 2, 1).contiguous(),
2 * self.patch_stride[1])
fine_grained_latent_output_32 = pooled_representation.permute(0, 2, 1).contiguous()
# get latent_output
latent_output = self.avgpool(torch.flatten(x, 2))
latent_output = torch.flatten(latent_output, 1)


# display the attention map, if needed
# if self.config.htsat_attn_heatmap:
@@ -886,6 +894,10 @@ def forward_features(self, x):
'framewise_output': fpx, # already sigmoided
'clipwise_output': torch.sigmoid(x),
'fine_grained_embedding': fine_grained_latent_output,
"fine_grained_latent_output_32":fine_grained_latent_output_32,
"fine_grained_latent_output_256":fine_grained_latent_output_256,
"latent_output_64_reshaped":latent_output_64_reshaped,
"latent_output": latent_output,
'embedding': latent_output
}

@@ -915,7 +927,6 @@ def reshape_wav2img(self, x):
x = nn.functional.interpolate(x, (x.shape[2], target_F), mode="bicubic", align_corners=True)
x = x.permute(0, 1, 3, 2).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2], self.freq_ratio, x.shape[3] // self.freq_ratio)
# print(x.shape)
x = x.permute(0, 1, 3, 2, 4).contiguous()
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
return x
@@ -936,7 +947,7 @@ def repeat_wat2img(self, x, cur_pos):
x = x.repeat(repeats=(1, 1, 4, 1))
return x

def forward(self, input: torch.Tensor, infer_mode=False): # out_feat_keys: List[str] = None):
def forward(self, input: torch.Tensor, infer_mode=False, representation='fine_grained_embedding'): # out_feat_keys: List[str] = None):
# x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
# x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)

Expand All @@ -951,7 +962,7 @@ def forward(self, input: torch.Tensor, infer_mode=False): # out_feat_keys: List
x = self.reshape_wav2img(x)
output_dict = self.forward_features(x)
# x = self.head(x)
return output_dict["fine_grained_embedding"]
return output_dict[representation]


if __name__ == '__main__':
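Taken together, these changes let a caller choose which HTSAT representation the audio encoder returns. Below is a minimal sketch of the intended wiring; the import paths and any constructor arguments besides `representation` are assumptions, while the representation keys are exactly those added to `forward_features` above.

```python
# Hedged sketch: picking an HTSAT representation via the encoder config.
# Import paths and non-`representation` constructor arguments are assumptions.
from src.modules.audio_encoder.audio_encoder_config import AudioEncoderConfig
from src.modules.audio_encoder.audio_encoder import AudioEncoderModel

config = AudioEncoderConfig(
    model_arch="transformer",  # HTSAT is the transformer branch
    representation="fine_grained_latent_output_32",
)
encoder = AudioEncoderModel(config)

# AudioEncoderModel.forward() now calls
#   self.audio_enc(input_ids, representation=self.representation)
# and HTSAT's forward() returns output_dict[representation], so any key of
# the dict built in forward_features is valid here: 'fine_grained_embedding',
# 'fine_grained_latent_output_32', 'fine_grained_latent_output_256',
# 'latent_output_64_reshaped', 'latent_output', or 'embedding'.
```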
24 changes: 17 additions & 7 deletions dataset_preperation/README.md
@@ -16,18 +16,28 @@ We introduce an efficient pipeline for collecting ambient audio. It starts by an
For initializing your environment, please refer to the [general README](../README.md).

## Autocap Dataset Download
<!-- - We currently provide the following datasets:
* autocap_audioset_vggsounds: containing roughly **445K** audio-text pairs, derived from VGGSounds and a subset of AudioSet. This dataset was not filtered to remove music and speech.
* AutoReCap-XL: containing around **57M** audio-text pairs, derived from Youtube videos. This dataset contain mainly ambinet audio clips and few speech and music clips. Please refer to the paper for more details on this dataset. -->

**datasets will be coming later!**
- We currently provide the following datasets:
<!-- * **autocap_audioset_vggsounds:** containing roughly **445K** audio-text pairs, derived from VGGSounds and a subset of AudioSet. This dataset was not filtered to remove music and speech. -->
* **AutoReCapXL:** containing more than **47M** audio-text pairs, filtered to have LAION CLAP similarity above 0.1.
* **AutoReCapXL-MQ:** containing more than **20.7M** audio-text pairs, filtered to have LAION CLAP similarity above 0.4.
* **AutoReCapXL-MQ-L:** containing more than **20.7M** audio-text pairs, filtered to have LAION CLAP similarity above 0.4 and audio clips longer than 5 seconds.
* **AutoReCapXL-HQ:** containing more than **10.7M** audio-text pairs, filtered to have LAION CLAP similarity above 0.5.

AutoReCap datasets are derived from YouTube videos. They contain mainly ambient audio clips and a few speech and music clips; please refer to the paper for more details. These datasets can be filtered by CLAP similarity threshold and minimum audio clip length, as described below.

```shell
python download.py --save_dir <path-to-save-dir> --dataset_name <dataset-subset>

# Example
python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL-HQ --audio_only

# Example of filtering according to clap similarity and audio clip length
python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL --clap_threshold 0.4 --min_audio_len 5 --audio_only

# Example of downloading only a subset of the datasets
python download.py --save_dir data/datasets/autocap --dataset_name AutoReCapXL-HQ --start_idx 0 --end_idx 100000 --audio_only
```
<!-- # Example -->
<!-- python download.py --save_dir data/datasets/autocap --dataset_name autocap_audioset_vggsounds --audio_only -->


By default, the script will download videos along with their metadata.
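For reference, here is a minimal sketch of the per-record filtering that `download.py` applies before downloading, assuming the JSON layout read by `read_video_segments_info`: one JSON object per line, mapping a video id to an `intervals` list whose entries carry `start`, `end`, and `CLAP_SIM` fields.

```python
import json

def filter_record(json_line, clap_threshold=0.4, min_audio_len=4.0):
    """Sketch of download.py's interval filtering (assumes one video id per line)."""
    data = json.loads(json_line)
    if not data:
        return {}
    video_id = next(iter(data))
    kept = []
    for clip in data[video_id].get("intervals", []):
        long_enough = float(clip["end"]) - float(clip["start"]) >= min_audio_len
        clap_sim = clip.get("CLAP_SIM")
        if long_enough and clap_sim is not None and clap_sim >= clap_threshold:
            kept.append(clip)
    return {video_id: {"intervals": kept}}
```

Records whose intervals are all filtered out simply contribute no clips to the download.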

45 changes: 40 additions & 5 deletions dataset_preperation/download.py
@@ -129,8 +129,11 @@ def update_interval_dict(dict_1, dict_2):

def read_video_segments_info(local_input_video_segments,
start_idx=0,
end_idx=int(1e9)):
end_idx=int(1e9),
min_audio_len=0.0,
clap_threshold=0.0):
all_video_segments = {}
total_number_of_clips = 0
with open(local_input_video_segments, 'r') as f:
last_idx = 0
for idx, json_str in enumerate(tqdm(f, desc="parsing json input")):
@@ -141,8 +144,21 @@
json_str = json_str[:-1]
if json_str.endswith(','):
json_str = json_str[:-1]
json_object = json.loads(json_str)
update_interval_dict(all_video_segments, json_object)

data = json.loads(json_str)
video_ids = list(data.keys())
if len(video_ids) == 0:
continue
video_id = video_ids[0]

intervals = data[video_id].get("intervals", [])
len_intervals_filtered = [clip for clip in intervals if float(clip['end']) - float(clip['start']) >= min_audio_len]
clap_len_intervals_filtered = [clip for clip in len_intervals_filtered if clip.get("CLAP_SIM", -9999) is not None and clip.get("CLAP_SIM", -9999) >= clap_threshold]
total_number_of_clips += len(clap_len_intervals_filtered)
video_data = {}
video_data[video_id] = {}
video_data[video_id]['intervals'] = clap_len_intervals_filtered
update_interval_dict(all_video_segments, video_data)
except Exception as e:
print("[ERROR] Couldn't parse json string:", json_str)
continue
@@ -151,6 +167,7 @@
if last_idx >= end_idx:
break

print(f"Found {total_number_of_clips} audio clips.")
return all_video_segments

def download_audioset_split(json_file,
@@ -163,14 +180,18 @@
end_idx=int(1e9),
num_processes=os.cpu_count(),
resume=True,
files_per_folder=5000
files_per_folder=5000,
clap_threshold=0.4,
min_audio_len=4,
):

os.makedirs(save_dir, exist_ok=True)

all_video_segments = read_video_segments_info(json_file,
start_idx=start_idx,
end_idx=end_idx)
end_idx=end_idx,
min_audio_len=min_audio_len,
clap_threshold=clap_threshold)

download_audio_split = partial(download_yt_video,
save_dir=save_dir,
@@ -208,6 +229,18 @@ def download_audioset_split(json_file,
required=True,
help=f"Provided the dataset names. Available datasets are {dataset_urls.keys()}")

parser.add_argument("--clap_threshold",
type=float,
required=False,
default=0.4,
help=f"Provided the clap similarity threshold to filter the dataset, default: 0.4")

parser.add_argument("--min_audio_len",
type=float,
required=False,
default=4,
help=f"Provided the minimum audio clip length to filter the dataset, default: 4s")

parser.add_argument("--input_file",
type=str,
default=None,
Expand Down Expand Up @@ -269,5 +302,7 @@ def download_audioset_split(json_file,
proxy_port=args.proxy,
start_idx=args.start_idx,
end_idx=args.end_idx,
clap_threshold=args.clap_threshold,
min_audio_len=args.min_audio_len,
resume=not args.redownload,
files_per_folder=args.files_per_folder)
7 changes: 5 additions & 2 deletions dataset_preperation/download_manager.py
@@ -2,8 +2,11 @@
import wget

save_dir = 'data/json_files'
dataset_urls = {}

dataset_urls = {"AutoReCapXL":'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_01.json',
"AutoReCapXL-MQ": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_04.json',
"AutoReCapXL-MQ-L": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_04_longer_than_5s.json',
"AutoReCapXL-HQ": 'https://huggingface.co/datasets/mali6/autocap/resolve/main/processed_snap-hdvila100m-w-clap-videos_segments_filtered_above_05.json'}


def get_dataset_json_file(dataset_name, dataset_json_file_path=None, download=True):
if dataset_json_file_path is None:
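A short usage sketch for the manager follows; the call matches the signature in the diff above, and the assumption (the function body is truncated here) is that it downloads the segments JSON into `data/json_files` and returns its local path.

```python
from download_manager import dataset_urls, get_dataset_json_file

print("Available datasets:", list(dataset_urls.keys()))

# Assumed to fetch the segments JSON for this subset and return its local path.
json_file = get_dataset_json_file("AutoReCapXL-HQ")
```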
