-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: update models of local backend to use torch, remove tensorflow
- Loading branch information
Showing 8 changed files with 120 additions and 89 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
from . import image | ||
from . import video | ||
# from . import image | ||
# from . import video | ||
|
||
# __all__ = ["image", "video"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,117 @@ | ||
import tensorflow as tf | ||
import cv2 | ||
import numpy as np | ||
from transformers import TFAutoModel, AutoConfig | ||
from tensorflow.keras.applications import ( | ||
EfficientNetV2B0, | ||
ResNet50, | ||
preprocess_input as keras_preprocess_input, | ||
) | ||
import torch | ||
import torchvision.transforms as transforms | ||
from torchvision.models import resnet50, efficientnet_v2_s | ||
|
||
from transformers import AutoModel, AutoFeatureExtractor | ||
|
||
# from tensorflow.keras.applications import ( | ||
# EfficientNetV2B0, | ||
# ResNet50, | ||
# ) | ||
|
||
|
||
# EmbeddingExtractor Class with model name as a parameter | ||
class EmbeddingExtractor: | ||
def __init__(self, model_name="EfficientNetV2B0"): | ||
self.model_name = model_name | ||
self.model, self.preprocess_fn = self.load_model() | ||
self.use_pretrained_model = False | ||
self.feature_extractor = None | ||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
self.load_model() | ||
|
||
def load_model(self): | ||
if self.model_name == "EfficientNetV2B0": | ||
base_model = EfficientNetV2B0(include_top=False, pooling="avg") | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
from torchvision.models import EfficientNet_V2_S_Weights | ||
base_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.DEFAULT) | ||
self.model = torch.nn.Sequential(*list(base_model.children())[:-1]).to( | ||
self.device | ||
) | ||
elif self.model_name == "ResNet50": | ||
base_model = ResNet50(include_top=False, pooling="avg") | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
from torchvision.models import ResNet50_Weights | ||
base_model = resnet50(weights=ResNet50_Weights.DEFAULT) | ||
self.model = torch.nn.Sequential(*list(base_model.children())[:-1]).to( | ||
self.device | ||
) | ||
else: | ||
config = AutoConfig.from_pretrained(self.model_name) | ||
base_model = TFAutoModel.from_pretrained(self.model_name, config=config) | ||
preprocess_fn = ( | ||
keras_preprocess_input # Define custom preprocessing if needed | ||
self.use_pretrained_model = True | ||
self.feature_extractor = AutoFeatureExtractor.from_pretrained( | ||
self.model_name | ||
) | ||
|
||
return ( | ||
tf.keras.Model(inputs=base_model.input, outputs=base_model.output), | ||
preprocess_fn, | ||
) | ||
self.model = AutoModel.from_pretrained(self.model_name).to(self.device) | ||
|
||
def preprocess_image(self, image): | ||
image = cv2.resize(image, (224, 224)) | ||
image = image.astype("float32") | ||
image = self.preprocess_fn(image) | ||
return image | ||
if self.use_pretrained_model: | ||
# If using Hugging Face models, do not apply torchvision transforms | ||
if image.shape[-1] == 3: # Ensure 3 channels for RGB | ||
image = cv2.resize(image, (2048, 2048)) # Adjust size as needed | ||
image = np.expand_dims(image, axis=0) # Add batch dimension | ||
return image | ||
else: | ||
raise ValueError( | ||
f"Expected 3 channels (RGB) but got {image.shape[-1]} channels." | ||
) | ||
else: | ||
transform = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
transforms.Resize((2048, 2048)), # Adjust size as needed | ||
transforms.Normalize( | ||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | ||
), | ||
] | ||
) | ||
image = transform(image) | ||
image = image.unsqueeze(0).to(self.device) | ||
return image | ||
|
||
def extract_image_embedding(self, image_path): | ||
image = cv2.imread(image_path) | ||
if image is not None: | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
image = self.preprocess_image(image) | ||
embedding = self.model.predict(np.expand_dims(image, axis=0)) | ||
return embedding.squeeze() | ||
with torch.no_grad(): | ||
if self.use_pretrained_model: | ||
inputs = self.feature_extractor( | ||
images=image, return_tensors="pt" | ||
).to(self.device) | ||
embedding = self.model(**inputs).last_hidden_state | ||
return embedding.cpu().numpy().squeeze() | ||
else: | ||
embedding = self.model(image).squeeze() | ||
return embedding.cpu().numpy() | ||
return None | ||
|
||
def process_image(self, image_id, file_path): | ||
def process_image(self, file_path): | ||
embedding = self.extract_image_embedding(file_path) | ||
if embedding is not None: | ||
print(f"Extracted embedding for image {image_id}: {embedding.shape}") | ||
print(f"Extracted embedding for image {file_path}: {embedding.shape}") | ||
else: | ||
print(f"Failed to extract embedding for image {image_id}") | ||
print(f"Failed to extract embedding for image {file_path}") | ||
return embedding | ||
|
||
|
||
if __name__ == "__main__": | ||
image_path = "path_to_your_image.jpg" | ||
import os | ||
import sys | ||
|
||
# Pass the model name as a parameter | ||
model_name = "microsoft/resnet-50" # Example for Hugging Face model | ||
extractor = EmbeddingExtractor(model_name=model_name) | ||
if len(sys.argv) < 2: | ||
print(f"Usage: {os.path.basename(__file__)} <filepath>") | ||
sys.exit(1) | ||
|
||
image_path = sys.argv[1] | ||
|
||
# Process the image | ||
extractor.process_image(1, image_path) | ||
def test_model(model_name): | ||
extractor = EmbeddingExtractor(model_name=model_name) | ||
extractor.process_image(image_path) | ||
|
||
# Pass the model name as a parameter | ||
for model_name in [ | ||
"EfficientNetV2B0", ## embedding.shape == (1280,) | ||
"ResNet50", ## embedding.shape == (2048,) | ||
"microsoft/resnet-50", # Example for Hugging Face model | ||
]: | ||
print(f"Testing model: {model_name}") | ||
test_model(model_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters