Skip to content

Commit

Permalink
Remove to_device from PTEngine (#2260)
Browse files Browse the repository at this point in the history
### Changes

Remove logic to set device in `PTEngine`, to support multi-device model
#2253
  • Loading branch information
AlexanderDokuchaev authored Nov 15, 2023
1 parent b4b2e19 commit 610e800
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from torchvision.models.detection.ssd import SSD
from torchvision.models.detection.ssd import GeneralizedRCNNTransform
from nncf.common.logging.track_progress import track
from functools import partial

ROOT = Path(__file__).parent.resolve()
DATASET_URL = "https://ultralytics.com/assets/coco128.zip"
Expand Down Expand Up @@ -125,10 +126,10 @@ def validate(model: torch.nn.Module, dataset: COCO128Dataset, device: torch.devi
return computed_metrics["map_50"]


def transform_fn(data_item: Tuple[torch.Tensor, Dict]) -> torch.Tensor:
def transform_fn(data_item: Tuple[torch.Tensor, Dict], device: torch.device) -> torch.Tensor:
# Skip label and add a batch dimension to an image tensor
images, _ = data_item
return images[None]
return images[None].to(device)


def main():
Expand All @@ -149,7 +150,7 @@ def main():
disable_tracing(SSD.postprocess_detections)

# Quantize model
calibration_dataset = nncf.Dataset(dataset, transform_fn)
calibration_dataset = nncf.Dataset(dataset, partial(transform_fn, device=device))
quantized_model = nncf.quantize(model, calibration_dataset)

# Convert to OpenVINO
Expand Down
9 changes: 0 additions & 9 deletions nncf/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
from torch import nn

from nncf.common.engine import Engine
from nncf.torch.nested_objects_traversal import objwalk
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_tensor


class PTEngine(Engine):
Expand All @@ -34,7 +31,6 @@ def __init__(self, model: nn.Module):

self._model = model
self._model.eval()
self._device = get_model_device(model)

def infer(
self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]
Expand All @@ -46,11 +42,6 @@ def infer(
:return: Model outputs.
"""

def send_to_device(tensor):
return tensor.to(self._device)

input_data = objwalk(input_data, is_tensor, send_to_device)

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
Expand Down

0 comments on commit 610e800

Please sign in to comment.