From f70241ba0aa05c7f44589bdb833c2259549b1593 Mon Sep 17 00:00:00 2001 From: dfan Date: Mon, 14 Sep 2020 13:21:59 -0700 Subject: [PATCH] Add device to VGGish and send audio tensor to device --- torchvggish/vggish.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchvggish/vggish.py b/torchvggish/vggish.py index ea693a4..13d9a23 100644 --- a/torchvggish/vggish.py +++ b/torchvggish/vggish.py @@ -141,12 +141,15 @@ def _vgg(): class VGGish(VGG): - def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True): + def __init__(self, urls, device=None, pretrained=True, preprocess=True, postprocess=True, progress=True): super().__init__(make_layers()) if pretrained: state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress) super().load_state_dict(state_dict) + if device is None: + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device self.preprocess = preprocess self.postprocess = postprocess if self.postprocess: @@ -162,10 +165,12 @@ def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, pro ) self.pproc.load_state_dict(state_dict) + self.to(self.device) def forward(self, x, fs=None): if self.preprocess: x = self._preprocess(x, fs) + x = x.to(self.device) x = VGG.forward(self, x) if self.postprocess: x = self._postprocess(x)