diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index b5be4fda7..5f602e165 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -60,6 +60,7 @@ def wrapper(samples, *args, **kwargs):
             logger.error(
                 f'An error occurred in mapper operation when processing '
                 f'samples {samples}, {type(e)}: {e}')
+            traceback.print_exc()
             ret = {key: [] for key in samples.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
@@ -97,6 +98,7 @@ def wrapper(sample, *args, **kwargs):
             logger.error(
                 f'An error occurred in mapper operation when processing '
                 f'sample {sample}, {type(e)}: {e}')
+            traceback.print_exc()
             ret = {key: [] for key in sample.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
diff --git a/data_juicer/ops/deduplicator/ray_video_deduplicator.py b/data_juicer/ops/deduplicator/ray_video_deduplicator.py
index 2f90a6fed..7193e9313 100644
--- a/data_juicer/ops/deduplicator/ray_video_deduplicator.py
+++ b/data_juicer/ops/deduplicator/ray_video_deduplicator.py
@@ -2,7 +2,8 @@
 from jsonargparse.typing import PositiveInt
 
-from data_juicer.utils.mm_utils import load_data_with_context, load_video
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video)
 
 from ..base_op import OPERATORS
 from ..op_fusion import LOADED_VIDEOS
@@ -52,4 +53,7 @@ def calculate_hash(self, sample, context=False):
                 if packet.stream.type == 'video':
                     md5_hash.update(bytes(packet))
 
+        for key in videos:
+            close_video(videos[key])
+
         return md5_hash.hexdigest()
diff --git a/data_juicer/ops/deduplicator/video_deduplicator.py b/data_juicer/ops/deduplicator/video_deduplicator.py
index 17b84c0ba..ed5a767a4 100644
--- a/data_juicer/ops/deduplicator/video_deduplicator.py
+++ b/data_juicer/ops/deduplicator/video_deduplicator.py
@@ -3,7 +3,8 @@
 from typing import Dict, Set, Tuple
 
 from data_juicer.utils.constant import HashKeys
-from data_juicer.utils.mm_utils import load_data_with_context, load_video
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video)
 
 from ..base_op import OPERATORS, Deduplicator
 from ..op_fusion import LOADED_VIDEOS
@@ -61,6 +62,9 @@ def compute_hash(self, sample, context=False):
                 if packet.stream.type == 'video':
                     md5_hash.update(bytes(packet))
 
+        for key in videos:
+            close_video(videos[key])
+
         sample[HashKeys.videohash] = md5_hash.hexdigest()
         return sample
diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py
index 03610396b..3d030b170 100644
--- a/data_juicer/ops/filter/video_aesthetics_filter.py
+++ b/data_juicer/ops/filter/video_aesthetics_filter.py
@@ -4,7 +4,7 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import (extract_key_frames,
+from data_juicer.utils.mm_utils import (close_video, extract_key_frames,
                                         extract_video_frames_uniformly,
                                         load_data_with_context, load_video)
 
@@ -181,7 +181,7 @@ def compute_stats(self, sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return sample
diff --git a/data_juicer/ops/filter/video_aspect_ratio_filter.py b/data_juicer/ops/filter/video_aspect_ratio_filter.py
index 8d1e654a2..49f684ebd 100644
--- a/data_juicer/ops/filter/video_aspect_ratio_filter.py
+++ b/data_juicer/ops/filter/video_aspect_ratio_filter.py
@@ -3,7 +3,8 @@
 import numpy as np
 
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import load_data_with_context, load_video
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video)
 
 from ..base_op import OPERATORS, Filter
 from ..op_fusion import LOADED_VIDEOS
@@ -67,7 +68,7 @@ def compute_stats(self, sample, context=False):
             video_aspect_ratios[
                 key] = stream.codec_context.width / stream.codec_context.height
             if not context:
-                video.close()
+                close_video(video)
 
         sample[Fields.stats][StatsKeys.video_aspect_ratios] = [
             video_aspect_ratios[key] for key in loaded_video_keys
diff --git a/data_juicer/ops/filter/video_duration_filter.py b/data_juicer/ops/filter/video_duration_filter.py
index e65a05c65..9d9653332 100644
--- a/data_juicer/ops/filter/video_duration_filter.py
+++ b/data_juicer/ops/filter/video_duration_filter.py
@@ -4,7 +4,8 @@
 from jsonargparse.typing import NonNegativeInt
 
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import load_data_with_context, load_video
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video)
 
 from ..base_op import OPERATORS, Filter
 from ..op_fusion import LOADED_VIDEOS
@@ -68,7 +69,7 @@ def compute_stats(self, sample, context=False):
             video_durations[video_key] = round(stream.duration *
                                                stream.time_base)
             if not context:
-                video.close()
+                close_video(video)
 
         # get video durations
         sample[Fields.stats][StatsKeys.video_duration] = [
diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py
index 8c3f65d5d..58600eeac 100644
--- a/data_juicer/ops/filter/video_frames_text_similarity_filter.py
+++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py
@@ -4,7 +4,8 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import (SpecialTokens, extract_key_frames,
+from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
+                                        extract_key_frames,
                                         extract_video_frames_uniformly,
                                         load_data_with_context, load_video,
                                         remove_special_tokens)
@@ -195,7 +196,7 @@ def compute_stats(self, sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return sample
diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py
index 380f657c5..c392b44d1 100644
--- a/data_juicer/ops/filter/video_nsfw_filter.py
+++ b/data_juicer/ops/filter/video_nsfw_filter.py
@@ -3,7 +3,7 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import (extract_key_frames,
+from data_juicer.utils.mm_utils import (close_video, extract_key_frames,
                                         extract_video_frames_uniformly,
                                         load_data_with_context, load_video)
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -156,7 +156,7 @@ def compute_stats(self, sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return sample
diff --git a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py
index 2a647541d..cbece9331 100644
--- a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py
+++ b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py
@@ -6,7 +6,8 @@
 from data_juicer import cuda_device_count
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import (extract_video_frames_uniformly,
+from data_juicer.utils.mm_utils import (close_video,
+                                        extract_video_frames_uniformly,
                                         load_data_with_context, load_video)
 
 from ..base_op import OPERATORS, UNFORKABLE, Filter
@@ -171,7 +172,7 @@ def compute_stats(self, sample, rank=None, context=False):
             video_ocr_area_ratios[video_key] = np.mean(frame_ocr_area_ratios)
 
             if not context:
-                container.close()
+                close_video(container)
 
         # get video durations
         sample[Fields.stats][StatsKeys.video_ocr_area_ratio] = [
diff --git a/data_juicer/ops/filter/video_resolution_filter.py b/data_juicer/ops/filter/video_resolution_filter.py
index cacf97b0a..f87aae4ca 100644
--- a/data_juicer/ops/filter/video_resolution_filter.py
+++ b/data_juicer/ops/filter/video_resolution_filter.py
@@ -4,7 +4,8 @@
 from jsonargparse.typing import PositiveInt
 
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import load_data_with_context, load_video
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video)
 
 from ..base_op import OPERATORS, Filter
 from ..op_fusion import LOADED_VIDEOS
@@ -91,7 +92,7 @@ def compute_stats(self, sample, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return sample
diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py
index e0a1c1e04..2deee0eaf 100644
--- a/data_juicer/ops/filter/video_watermark_filter.py
+++ b/data_juicer/ops/filter/video_watermark_filter.py
@@ -3,7 +3,7 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields, StatsKeys
-from data_juicer.utils.mm_utils import (extract_key_frames,
+from data_juicer.utils.mm_utils import (close_video, extract_key_frames,
                                         extract_video_frames_uniformly,
                                         load_data_with_context, load_video)
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -157,7 +157,7 @@ def compute_stats(self, sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return sample
diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
index 8cbbb4e66..4e1e5d28e 100644
--- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py
@@ -1,3 +1,4 @@
+# yapf: disable
 import copy
 import random
 
@@ -8,7 +9,8 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import HashKeys
-from data_juicer.utils.mm_utils import (SpecialTokens, extract_key_frames,
+from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
+                                        extract_key_frames,
                                         extract_video_frames_uniformly,
                                         insert_texts_after_placeholders,
                                         load_data_with_context, load_video,
@@ -285,7 +287,7 @@ def _process_single_sample(self, ori_sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return generated_samples
 
     def _reduce_captions(self, chunk, generated_text_candidates_single_chunk):
diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
index aa56b623f..dd5eca05f 100644
--- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
+++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py
@@ -1,3 +1,4 @@
+# yapf: disable
 import copy
 import random
 
@@ -8,7 +9,8 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import HashKeys
-from data_juicer.utils.mm_utils import (SpecialTokens, extract_key_frames,
+from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
+                                        extract_key_frames,
                                         extract_video_frames_uniformly,
                                         insert_texts_after_placeholders,
                                         load_data_with_context, load_video,
@@ -292,7 +294,7 @@ def _process_single_sample(self, ori_sample, rank=None, context=False):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         return generated_samples
 
     def _reduce_captions(self, chunk, generated_text_candidates_single_chunk):
diff --git a/data_juicer/ops/mapper/video_face_blur_mapper.py b/data_juicer/ops/mapper/video_face_blur_mapper.py
index 05de74cd6..3ffac112b 100644
--- a/data_juicer/ops/mapper/video_face_blur_mapper.py
+++ b/data_juicer/ops/mapper/video_face_blur_mapper.py
@@ -3,8 +3,9 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import transfer_filename
-from data_juicer.utils.mm_utils import (load_data_with_context, load_video,
-                                        pil_to_opencv, process_each_frame)
+from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
+                                        load_video, pil_to_opencv,
+                                        process_each_frame)
 
 from ..base_op import OPERATORS, Mapper
 from ..op_fusion import LOADED_VIDEOS
@@ -93,7 +94,7 @@ def process(self, sample, context=False):
             processed_video_keys[video_key] = output_video_key
 
             if not context:
-                video.close()
+                close_video(video)
 
         # when the file is modified, its source file needs to be updated.
         for i, value in enumerate(loaded_video_keys):
diff --git a/data_juicer/ops/mapper/video_remove_watermark_mapper.py b/data_juicer/ops/mapper/video_remove_watermark_mapper.py
index 316c47223..f99929439 100644
--- a/data_juicer/ops/mapper/video_remove_watermark_mapper.py
+++ b/data_juicer/ops/mapper/video_remove_watermark_mapper.py
@@ -8,7 +8,8 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import transfer_filename
 from data_juicer.utils.logger_utils import HiddenPrints
-from data_juicer.utils.mm_utils import (extract_video_frames_uniformly,
+from data_juicer.utils.mm_utils import (close_video,
+                                        extract_video_frames_uniformly,
                                         load_data_with_context, load_video,
                                         parse_string_to_roi,
                                         process_each_frame)
@@ -233,7 +234,7 @@ def process_frame_func(frame):
 
         if not context:
             for vid_key in videos:
-                videos[vid_key].close()
+                close_video(videos[vid_key])
 
         # when the file is modified, its source file needs to be updated.
         for i, value in enumerate(sample[self.video_key]):
diff --git a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py
index 20b969438..e226c2651 100644
--- a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py
+++ b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py
@@ -6,7 +6,7 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import transfer_filename
 from data_juicer.utils.logger_utils import HiddenPrints
-from data_juicer.utils.mm_utils import load_video
+from data_juicer.utils.mm_utils import close_video, load_video
 
 from ..base_op import OPERATORS, Mapper
@@ -117,7 +117,7 @@ def process(self, sample):
             original_width = video.codec_context.width
             original_height = video.codec_context.height
             original_aspect_ratio = Fraction(original_width, original_height)
-            container.close()
+            close_video(container)
 
             if (original_aspect_ratio >= self.min_ratio
                     and original_aspect_ratio <= self.max_ratio):
diff --git a/data_juicer/ops/mapper/video_resize_resolution_mapper.py b/data_juicer/ops/mapper/video_resize_resolution_mapper.py
index 03f72f914..a88d3758d 100644
--- a/data_juicer/ops/mapper/video_resize_resolution_mapper.py
+++ b/data_juicer/ops/mapper/video_resize_resolution_mapper.py
@@ -8,7 +8,7 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import transfer_filename
 from data_juicer.utils.logger_utils import HiddenPrints
-from data_juicer.utils.mm_utils import load_video
+from data_juicer.utils.mm_utils import close_video, load_video
 
 from ..base_op import OPERATORS, Mapper
 from ..op_fusion import LOADED_VIDEOS
@@ -102,7 +102,7 @@ def process(self, sample, context=False):
             width = video.codec_context.width
             height = video.codec_context.height
             origin_ratio = width / height
-            container.close()
+            close_video(container)
 
             if width >= self.min_width and width <= self.max_width and \
                     height >= self.min_height and height <= self.max_height:
diff --git a/data_juicer/ops/mapper/video_split_by_duration_mapper.py b/data_juicer/ops/mapper/video_split_by_duration_mapper.py
index d7626a54e..0a41d3240 100644
--- a/data_juicer/ops/mapper/video_split_by_duration_mapper.py
+++ b/data_juicer/ops/mapper/video_split_by_duration_mapper.py
@@ -6,7 +6,8 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import (add_suffix_to_filename,
                                           transfer_filename)
-from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds,
+from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
+                                        cut_video_by_seconds,
                                         get_video_duration, load_video)
 
 from ..base_op import OPERATORS, Mapper
@@ -123,7 +124,7 @@ def _process_single_sample(self, sample):
             video = videos[video_key]
             new_video_keys = self.split_videos_by_duration(
                 video_key, video)
-            video.close()
+            close_video(video)
             split_video_keys.extend(new_video_keys)
             place_holders.append(SpecialTokens.video * len(new_video_keys))
diff --git a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py
index 4a8d276aa..0a4a7c593 100644
--- a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py
+++ b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py
@@ -4,7 +4,8 @@
 from data_juicer.utils.constant import Fields
 from data_juicer.utils.file_utils import (add_suffix_to_filename,
                                           transfer_filename)
-from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds,
+from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
+                                        cut_video_by_seconds,
                                         get_key_frame_seconds, load_video)
 
 from ..base_op import OPERATORS, Mapper
@@ -105,7 +106,7 @@ def _process_single_sample(self, sample):
                 video_count]:
             video = videos[video_key]
             new_video_keys = self.get_split_key_frame(video_key, video)
-            video.close()
+            close_video(video)
             split_video_keys.extend(new_video_keys)
             place_holders.append(SpecialTokens.video * len(new_video_keys))
diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
index 2fdb68e49..4ac0944c8 100644
--- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
+++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
@@ -4,7 +4,7 @@
 from data_juicer.utils.availability_utils import AvailabilityChecking
 from data_juicer.utils.constant import Fields
-from data_juicer.utils.mm_utils import (extract_key_frames,
+from data_juicer.utils.mm_utils import (close_video, extract_key_frames,
                                         extract_video_frames_uniformly,
                                         load_data_with_context, load_video)
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -110,5 +110,10 @@ def process(self, sample, rank=None, context=False):
             word_count = Counter(words)
             sorted_word_list = [item for item, _ in word_count.most_common()]
             video_tags.append(sorted_word_list)
+
+        if not context:
+            for vid_key in videos:
+                close_video(videos[vid_key])
+
         sample[Fields.video_frame_tags] = video_tags
         return sample
diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py
index 5953f1e53..dfc37d45d 100644
--- a/data_juicer/ops/op_fusion.py
+++ b/data_juicer/ops/op_fusion.py
@@ -160,6 +160,7 @@ def compute_stats(self, sample, rank=None):
         for context_key in sample[Fields.context]:
             if isinstance(sample[Fields.context][context_key],
                           av.container.InputContainer):
+                sample[Fields.context][context_key].streams.video[0].close()
                 sample[Fields.context][context_key].close()
         _ = sample.pop(Fields.context)
         return sample
diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py
index 3cdb90220..5e0e9d4e7 100644
--- a/data_juicer/utils/mm_utils.py
+++ b/data_juicer/utils/mm_utils.py
@@ -319,8 +319,8 @@ def cut_video_by_seconds(
     # close the output videos
     if isinstance(input_video, str):
-        container.close()
-    output_container.close()
+        close_video(container)
+    close_video(output_container)
 
     if not os.path.exists(output_video):
         logger.warning(f'This video could not be successfully cut in '
                        f'[{start_seconds}, {end_seconds}] seconds. '
@@ -384,8 +384,8 @@ def process_each_frame(input_video: Union[str, av.container.InputContainer],
     # close the output videos
     if isinstance(input_video, str):
-        container.close()
-    output_container.close()
+        close_video(container)
+    close_video(output_container)
 
     if frame_modified:
         return output_video
@@ -433,7 +433,7 @@ def extract_key_frames(input_video: Union[str, av.container.InputContainer]):
             break
 
     if isinstance(input_video, str):
-        container.close()
+        close_video(container)
 
     return key_frames
@@ -559,7 +559,7 @@ def extract_video_frames_uniformly(
     # if the container is opened in this function, close it
     if isinstance(input_video, str):
-        container.close()
+        close_video(container)
 
     return extracted_frames
@@ -663,9 +663,9 @@ def extract_audio_from_video(
             output_container.mux(packet)
 
         if isinstance(input_video, str):
-            input_container.close()
+            close_video(input_container)
         if output_audio:
-            output_container.close()
+            close_video(output_container)
 
         audio_data_list.append(np.concatenate(audio_data))
 
     return audio_data_list, audio_sampling_rate_list, valid_stream_indexes
@@ -786,3 +786,14 @@ def parse_string_to_roi(roi_string, roi_type='pixel'):
                 '"[x1, y1, x2, y2]".')
             return None
     return None
+
+
+def close_video(container):
+    """
+    Close the video stream and container to avoid memory leak.
+
+    :param container: the video container.
+    """
+    for video_stream in container.streams.video:
+        video_stream.close(strict=False)
+    container.close()
diff --git a/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py
index 53e39b820..44df6dd88 100644
--- a/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py
+++ b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py
@@ -5,7 +5,7 @@
 from data_juicer.ops.mapper.video_ffmpeg_wrapped_mapper import \
     VideoFFmpegWrappedMapper
-from data_juicer.utils.mm_utils import load_video
+from data_juicer.utils.mm_utils import close_video, load_video
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
 
@@ -31,7 +31,7 @@ def get_size(dataset):
                 width = video.streams.video[0].codec_context.width
                 height = video.streams.video[0].codec_context.height
                 sample_list.append((width, height))
-                video.close()
+                close_video(video)
             sizes.append(sample_list)
         return sizes
diff --git a/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py
index 2cd09e86e..406e530fb 100644
--- a/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py
+++ b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py
@@ -5,7 +5,7 @@
 from data_juicer.ops.mapper.video_resize_aspect_ratio_mapper import \
     VideoResizeAspectRatioMapper
-from data_juicer.utils.mm_utils import load_video
+from data_juicer.utils.mm_utils import close_video, load_video
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase
 
@@ -30,7 +30,7 @@ def get_size(dataset):
                 width = video.streams.video[0].codec_context.width
                 height = video.streams.video[0].codec_context.height
                 sample_list.append((width, height))
-                video.close()
+                close_video(video)
             sizes.append(sample_list)
         return sizes
diff --git a/tools/mm_eval/inception_metrics/dataset.py b/tools/mm_eval/inception_metrics/dataset.py
index 15604763d..d4bd96019 100644
--- a/tools/mm_eval/inception_metrics/dataset.py
+++ b/tools/mm_eval/inception_metrics/dataset.py
@@ -11,6 +11,8 @@
 import torch
 from torch.utils.data import Dataset
 
+from data_juicer.utils.mm_utils import load_video, close_video
+
 
 @dataclass
 class VideoDataset(Dataset):
     dataset_path: str
@@ -38,7 +40,7 @@ def __post_init__(self):
                 self.video_paths.append(video_path)
 
     def sample_frames(self, video_path):
-        container = av.open(video_path)
+        container = load_video(video_path)
         input_video_stream = container.streams.video[0]
         total_frame_num = input_video_stream.frames
@@ -69,7 +71,7 @@ def sample_frames(self, video_path):
                 sampled_frames.append(tensor_frame)
             frame_id += 1
 
-        container.close()
+        close_video(container)
 
         assert frame_id >= total_frame_num, 'frame num error'
         return sampled_frames, spacing
@@ -111,16 +113,16 @@ def __post_init__(self):
             for video_path in data[self.video_key]:
                 if self.mm_dir is not None:
                     video_path = os.path.join(self.mm_dir, video_path)
-                container = av.open(video_path)
+                container = load_video(video_path)
                 input_video_stream = container.streams.video[0]
                 total_frame_num = input_video_stream.frames
                 num_samples_from_source = total_frame_num - self.seq_length + 1
                 for start_frame in range(0, num_samples_from_source):
                     self.start_frames.append((video_path, start_frame,
                                               num_samples_from_source))
-                container.close()
+                close_video(container)
 
     def read_frames(self, video_path, start_index):
-        container = av.open(video_path)
+        container = load_video(video_path)
         input_video_stream = container.streams.video[0]
 
         sampled_idxs = set(range(start_index, start_index + self.seq_length))
@@ -138,7 +140,7 @@ def read_frames(self, video_path, start_index):
                 sampled_frames.append(tensor_frame)
             frame_id += 1
 
-        container.close()
+        close_video(container)
 
         return sampled_frames
 
     def __getitem__(self, index: int) -> dict:
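
Every call site above follows the same pattern: containers obtained from load_video are released through the new close_video helper instead of calling container.close() directly, so the per-stream decoder state is released along with the container. A minimal usage sketch follows; it is illustrative only, assumes the load_video/close_video helpers from data_juicer.utils.mm_utils as defined in this diff, and uses a placeholder file path.

    from data_juicer.utils.mm_utils import close_video, load_video

    # open a container, read what is needed, then release it with the helper
    container = load_video('path/to/example.mp4')  # hypothetical path
    try:
        stream = container.streams.video[0]
        print(stream.codec_context.width, stream.codec_context.height)
    finally:
        # close_video() closes each video stream of the container before
        # closing the container itself, mirroring the helper added above.
        close_video(container)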