Skip to content

Commit

Permalink
Merge pull request #30 from 9527-csroad/9527-csroad-patch-1
Browse files Browse the repository at this point in the history
Update segmentation_pipeline.py
  • Loading branch information
wiktorlazarski authored Sep 7, 2023
2 parents cc16033 + d1e4b3b commit c1b4b9b
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,34 @@ cd head-segmentation
streamlit run ./scripts/apps/web_checking.py
```

## ⏰ Inference time

If you are strict with time, you can use gpu to acclerate inference. Visualization also consume some time, you can just save the final result as below.

```python
import torch
from PIL import Image
import head_segmentation.segmentation_pipeline as seg_pipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

segmentation_pipeline = seg_pipeline.HumanHeadSegmentationPipeline(device=device)

segmentation_map = segmentation_pipeline.predict(image)

segmented_region = image * cv2.cvtColor(segmentation_map, cv2.COLOR_GRAY2RGB)

pil_image = Image.fromarray(segmented_region)
pil_image.save(save_path)
```

The table below presents inference time which is tested on Tesla T4 (just for reference). The first image will take more time.

| | save figure | just save final result|
|:--------------:|:---------------------:|:---------------------:|
| cpu | around 2.1s | around 0.8s |
| gpu | around 1.4s | around 0.15s |

<div align="center">

### 🤗 Enjoy the model!
Expand Down
3 changes: 2 additions & 1 deletion head_segmentation/_version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# __version__ = "1.0.0" # First version
__version__ = "1.1.0" # Segmentation model downloaded from GDrive
# __version__ = "1.1.0" # Segmentation model downloaded from GDrive
__version__ = "1.3.0"
5 changes: 5 additions & 0 deletions head_segmentation/segmentation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
self,
model_path: str = C.HEAD_SEGMENTATION_MODEL_PATH,
model_url: str = C.HEAD_SEGMENTATION_MODEL_URL,
device: torch.device = torch.device('cpu')
):
if not os.path.exists(model_path):
model_path = C.HEAD_SEGMENTATION_MODEL_PATH
Expand All @@ -26,6 +27,7 @@ def __init__(

gdown.download(model_url, model_path, quiet=False)

self.device = device
ckpt = torch.load(model_path, map_location=torch.device("cpu"))
hparams = ckpt["hyper_parameters"]

Expand All @@ -35,14 +37,17 @@ def __init__(
self._model = mdl.HeadSegmentationModel.load_from_checkpoint(
ckpt_path=model_path
)
self._model.to(self.device)
self._model.eval()

def __call__(self, image: np.ndarray) -> np.ndarray:
return self.predict(image)

def predict(self, image: np.ndarray) -> np.ndarray:
preprocessed_image = self._preprocess_image(image)
preprocessed_image = preprocessed_image.to(self.device)
mdl_out = self._model(preprocessed_image)
mdl_out = mdl_out.cpu()
pred_segmap = self._postprocess_model_output(mdl_out, original_image=image)
return pred_segmap

Expand Down

0 comments on commit c1b4b9b

Please sign in to comment.