Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoZhang534 committed Aug 21, 2023
1 parent 6af0cda commit 4c47542
Show file tree
Hide file tree
Showing 12 changed files with 16 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
.idea
*.iml
out
output
gen
*.out
.gitignore
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export DATASET=/pth/to/dataset
Download the pretrained checkpoint from [here](https://github.com/IDEA-Research/OpenSeeD/releases/download/openseed/model_state_dict_swint_51.2ap.pt).
### :bulb: Demo script
```sh
python demo/demo_panoseg.py evaluate --conf_files configs/openseed/openseed_swint_lang.yaml --image_path images/your_image.jpg --overrides WEIGHT /path/to/ckpt/model_state_dict_swint_51.2ap.pt
python demo/demo_panoseg.py evaluate --conf_files configs/openseed/openseed_swint_lang.yaml --image_path images/animals.png --overrides WEIGHT /path/to/ckpt/model_state_dict_swint_51.2ap.pt
```
:fire: Remember to **modify the vocabulary** `thing_classes` and `stuff_classes` in `demo_panoseg.py` if your want to segment open-vocabulary objects.

Expand Down
2 changes: 1 addition & 1 deletion datasets/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
COCOPanopticEvaluator,
)
from openseed.utils import configurable
from utils.distributed import get_world_size
from detectron2.utils.comm import get_world_size
from typing import Any, Dict, List, Set

class JointLoader(torchdata.IterableDataset):
Expand Down
2 changes: 1 addition & 1 deletion datasets/evaluation/segmentation_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from detectron2.utils.comm import all_gather, is_main_process
from detectron2.utils.file_io import PathManager
from detectron2.evaluation.evaluator import DatasetEvaluator
from utils.distributed import synchronize
from detectron2.utils.comm import synchronize

from ..semseg_loader import load_semseg

Expand Down
16 changes: 4 additions & 12 deletions demo/demo_instseg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Xueyan Zou ([email protected])
# --------------------------------------------------------

import os
import sys
import logging
Expand All @@ -27,7 +20,6 @@
from openseed import build_model
from detectron2.utils.colormap import random_color
from utils.visualizer import Visualizer
from utils.distributed import init_distributed


logger = logging.getLogger(__name__)
Expand All @@ -42,20 +34,20 @@ def main(args=None):
absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
opt['user_dir'] = absolute_user_dir

opt = init_distributed(opt)
# opt = init_distributed(opt)

# META DATA
pretrained_pth = os.path.join(opt['WEIGHT'])
output_root = './output'
image_pth = 'images/owls.jpeg'
image_pth = cmdline_args.image_path

model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()

t = []
t.append(transforms.Resize(800, interpolation=Image.BICUBIC))
transform = transforms.Compose(t)

thing_classes = ["owl"]
thing_classes=["zebra","giraffe","tree","ostrich"]
thing_colors = [random_color(rgb=True, maximum=255).astype(np.int).tolist() for _ in range(len(thing_classes))]
thing_dataset_id_to_contiguous_id = {x:x for x in range(len(thing_classes))}

Expand All @@ -65,7 +57,7 @@ def main(args=None):
thing_dataset_id_to_contiguous_id=thing_dataset_id_to_contiguous_id,
)
# model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + ["background"], is_eval=False)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes, is_eval=False)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes, is_eval=True)
metadata = MetadataCatalog.get('demo')
model.model.metadata = metadata
model.model.sem_seg_head.num_classes = len(thing_classes)
Expand Down
6 changes: 2 additions & 4 deletions demo/demo_panoseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from openseed.BaseModel import BaseModel
from openseed import build_model
from utils.visualizer import Visualizer
from utils.distributed import init_distributed


logger = logging.getLogger(__name__)

Expand All @@ -40,8 +40,6 @@ def main(args=None):
absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
opt['user_dir'] = absolute_user_dir

opt = init_distributed(opt)

# META DATA
pretrained_pth = os.path.join(opt['WEIGHT'])
output_root = './output'
Expand All @@ -68,7 +66,7 @@ def main(args=None):
stuff_classes=stuff_classes,
stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes + ["background"], is_eval=False)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(thing_classes + stuff_classes, is_eval=False)
metadata = MetadataCatalog.get('demo')
model.model.metadata = metadata
model.model.sem_seg_head.num_classes = len(thing_classes + stuff_classes)
Expand Down
7 changes: 2 additions & 5 deletions demo/demo_semseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from openseed.BaseModel import BaseModel
from openseed import build_model
from utils.visualizer import Visualizer
from utils.distributed import init_distributed


logger = logging.getLogger(__name__)
Expand All @@ -40,7 +39,6 @@ def main(args=None):
if cmdline_args.user_dir:
absolute_user_dir = os.path.abspath(cmdline_args.user_dir)
opt['user_dir'] = absolute_user_dir
opt = init_distributed(opt)

# META DATA
pretrained_pth = os.path.join(opt['WEIGHT'])
Expand All @@ -62,7 +60,7 @@ def main(args=None):
stuff_classes=stuff_classes,
stuff_dataset_id_to_contiguous_id=stuff_dataset_id_to_contiguous_id,
)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes + ["background"], is_eval=True)
model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(stuff_classes, is_eval=True)
metadata = MetadataCatalog.get('demo')
model.model.metadata = metadata
model.model.sem_seg_head.num_classes = len(stuff_classes)
Expand All @@ -77,9 +75,8 @@ def main(args=None):
images = torch.from_numpy(image.copy()).permute(2,0,1).cuda()

batch_inputs = [{'image': images, 'height': height, 'width': width}]
outputs = model.forward(batch_inputs)
outputs = model.forward(batch_inputs,inference_task="sem_seg")
visual = Visualizer(image_ori, metadata=metadata)

sem_seg = outputs[-1]['sem_seg'].max(0)[1]
demo = visual.draw_sem_seg(sem_seg.cpu(), alpha=0.5) # rgb Image

Expand Down
Binary file removed images/owls.jpeg
Binary file not shown.
2 changes: 1 addition & 1 deletion openseed/BaseModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def forward(self, *inputs, **kwargs):
return outputs

def save_pretrained(self, save_dir):
torch.save(self.model.state_dict(), save_path)
torch.save(self.model.state_dict(), save_dir)

def from_pretrained(self, load_dir):
state_dict = torch.load(load_dir, map_location='cpu')
Expand Down
2 changes: 1 addition & 1 deletion openseed/language/LangEncoder/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from timm.models.layers import DropPath, trunc_normal_

from .registry import register_lang_encoder
from utils.distributed import is_main_process
from detectron2.utils.comm import is_main_process
from utils.model import register_norm_module

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from detectron2.config import LazyConfig, instantiate

from utils.arguments import load_opt_command
from utils.distributed import init_distributed, is_main_process, apply_distributed, synchronize
from detectron2.utils.comm import get_world_size, is_main_process

# MaskDINO

Expand Down
2 changes: 1 addition & 1 deletion utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import pickle
import torch
from utils.distributed import is_main_process
from detectron2.utils.comm import is_main_process

logger = logging.getLogger(__name__)

Expand Down

0 comments on commit 4c47542

Please sign in to comment.