-
Notifications
You must be signed in to change notification settings - Fork 2
/
bounds.py
62 lines (44 loc) · 2.15 KB
/
bounds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3.6
from typing import Callable
import torch
from torch import Tensor
class PreciseBounds():
def __init__(self, **kwargs):
self.margin: float = kwargs['margin']
self.mode: str = kwargs['mode']
self.__fn__ = getattr(__import__('utils'), kwargs['fn'])
print(f"Initialized {self.__class__.__name__} with {kwargs}")
def __call__(self, image: Tensor, target: Tensor, weak_target: Tensor, filename: str) -> Tensor:
value: Tensor = self.__fn__(target[None, ...])[0].type(torch.float32) # cwh and not bcwh
margin: Tensor
if self.mode == "percentage":
margin = value * self.margin
elif self.mode == "abs":
margin = torch.ones_like(value) * self.margin
else:
raise ValueError("mode")
with_margin: Tensor = torch.stack([value - margin, value + margin], dim=-1)
assert with_margin.shape == (*value.shape, 2), with_margin.shape
res = torch.max(with_margin, torch.zeros(*value.shape, 2)).type(torch.float32)
return res
class PredictionBounds():
def __init__(self, **kwargs):
self.margin: float = kwargs['margin']
self.mode: str = kwargs['mode']
# Do it on CPU to avoid annoying the main loop
self.net: Callable[Tensor, [Tensor]] = torch.load(kwargs['net'], map_location='cpu')
print(f"Initialized {self.__class__.__name__} with {kwargs}")
def __call__(self, image: Tensor, target: Tensor, weak_target: Tensor, filename: str) -> Tensor:
with torch.no_grad():
value: Tensor = self.net(image[None, ...])[0].type(torch.float32)[..., None] # cwh and not bcwh
margin: Tensor
if self.mode == "percentage":
margin = value * self.margin
elif self.mode == "abs":
margin = torch.ones_like(value) * self.margin
else:
raise ValueError("mode")
with_margin: Tensor = torch.stack([value - margin, value + margin], dim=-1)
assert with_margin.shape == (*value.shape, 2), with_margin.shape
res = torch.max(with_margin, torch.zeros(*value.shape, 2)).type(torch.float32)
return res