forked from purnasai/Dino_V2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
4.0.generate_features_Dino.py
104 lines (87 loc) · 3.55 KB
/
4.0.generate_features_Dino.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Author: Purnasai
Description: This file generates DINOv2 image features for a directory
of images and stores them in an HDF5 file via h5py.
"""
import os
import random
import h5py
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import ssl
# Load the DINOv2 backbone from torch.hub.
# NOTE: despite the variable name (kept because the extraction loop uses it),
# this loads the *small* ViT-S/14 model ('dinov2_vits14'), not ViT-L/14.
#
# torch.hub downloads the repository and then executes its code, so TLS
# certificate verification matters here. The unverified-context workaround
# (presumably added to get past a local certificate problem) is applied only
# for the duration of this download and restored afterwards, instead of
# disabling certificate checks for every HTTPS request in the process.
# NOTE(security): prefer fixing the local CA bundle and deleting this
# workaround entirely - an unverified download is open to tampering.
_default_https_context = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context
try:
    dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
finally:
    ssl._create_default_https_context = _default_https_context
def get_labels(files):
    """
    Return the unique parent-directory names of the given file paths.

    Backslash (Windows) and forward-slash (Unix) separators are both accepted,
    in any mix, so paths built with ``os.path.join`` on either platform work.

    :param files: iterable of file paths (strings).
    :return: list of unique labels in first-seen order, where a file's label
        is the name of the directory that directly contains it ("" for a bare
        filename with no directory part).
    """
    labels = []
    seen = set()  # O(1) membership instead of scanning `labels` each time
    for file_path in files:
        normalized = file_path.replace("\\", "/")
        directory = normalized.rpartition("/")[0]
        label = directory.rpartition("/")[2]
        if label not in seen:
            seen.add(label)
            labels.append(label)
    return labels
def list_files(dataset_path):
    """
    Collect the path of every file below a directory, recursively.

    :param dataset_path: root directory to walk.
    :return: list of file paths, each formed with ``os.path.join(root, name)``,
        in the order ``os.walk`` yields them (empty if nothing is found).
    """
    return [
        os.path.join(folder, filename)
        for folder, _subdirs, filenames in os.walk(dataset_path)
        for filename in filenames
    ]
class CustomImageDataset(Dataset):
    """
    Dataset yielding ``(image_tensor, path)`` for every file under a directory.

    Files are discovered recursively with ``list_files``; each one is opened
    with PIL, converted to RGB and passed through ``self.transform``.
    NOTE: files are not filtered by extension, so a non-image file in the
    directory will make ``__getitem__`` raise when PIL fails to open it.
    """

    def __init__(self, img_dir, transform=None):
        """
        :param img_dir: root directory scanned recursively for images.
        :param transform: optional callable applied to each RGB PIL image.
            Defaults to Resize(256) -> CenterCrop(224) -> ToTensor ->
            Normalize with the standard ImageNet mean/std values.
        """
        self.img_dir = img_dir
        self.images = list_files(self.img_dir)
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ])
        self.transform = transform

    def __len__(self):
        """Return the number of files found under ``img_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, transform it, and return it with its path."""
        img_path = self.images[idx]
        # Context manager guarantees the underlying file handle is released.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, img_path
dir_path = "./Data/"
dataset = CustomImageDataset(dir_path)
# Order is irrelevant for feature extraction (each feature row is stored next
# to its file path), so iterate deterministically for reproducible output.
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=False)

final_img_features = []
final_img_filepaths = []

# Pure inference: eval mode, and no autograd graph is built per batch.
dinov2_vitl14.eval()
with torch.no_grad():
    for image_tensors, file_paths in tqdm(train_dataloader):
        try:
            image_features = dinov2_vitl14(image_tensors)  # 384 small, 768 base, 1024 large
            # L2-normalise each row so dot products between rows are cosine similarities.
            image_features /= image_features.norm(dim=-1, keepdim=True)
        except Exception as e:
            # Best effort: stop at the first failing batch, keep what was extracted so far.
            print("Exception occurred: ", e)
            break
        else:
            # Extend both lists together so features and paths stay aligned.
            final_img_features.extend(image_features.tolist())
            final_img_filepaths.extend(list(file_paths))

# h5py does not create missing parent directories.
os.makedirs("features", exist_ok=True)
with h5py.File('features/image_features_dino.h5', 'w') as h5f:
    h5f.create_dataset("image_features", data=np.array(final_img_features))
    # Store the file names as variable-length strings. The explicit string dtype
    # is needed for object arrays on older h5py, and for the empty-list case.
    h5f.create_dataset(
        "image_filenames",
        data=np.array(final_img_filepaths, dtype=object),
        dtype=h5py.string_dtype(),
    )