diff --git a/references/video_classification/train.py b/references/video_classification/train.py index 945c8c67c76..5cfd9a6bf43 100644 --- a/references/video_classification/train.py +++ b/references/video_classification/train.py @@ -164,7 +164,7 @@ def main(args): if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_train from {cache_path}") - dataset, _ = torch.load(cache_path, weights_only=True) + dataset, _ = torch.load(cache_path) dataset.transform = transform_train else: if args.distributed: @@ -201,7 +201,7 @@ def main(args): if args.cache_dataset and os.path.exists(cache_path): print(f"Loading dataset_test from {cache_path}") - dataset_test, _ = torch.load(cache_path, weights_only=True) + dataset_test, _ = torch.load(cache_path) dataset_test.transform = transform_test else: if args.distributed: