Skip to content

Commit

Permalink
Merge pull request #19 from dfan/master
Browse files Browse the repository at this point in the history
Add device to VGGish and send audio tensor to device
  • Loading branch information
Harri Taylor authored Apr 8, 2021
2 parents e1e2273 + f70241b commit 4670116
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchvggish/vggish.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,15 @@ def _vgg():


class VGGish(VGG):
def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, progress=True):
def __init__(self, urls, device=None, pretrained=True, preprocess=True, postprocess=True, progress=True):
super().__init__(make_layers())
if pretrained:
state_dict = hub.load_state_dict_from_url(urls['vggish'], progress=progress)
super().load_state_dict(state_dict)

if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.device = device
self.preprocess = preprocess
self.postprocess = postprocess
if self.postprocess:
Expand All @@ -162,10 +165,12 @@ def __init__(self, urls, pretrained=True, preprocess=True, postprocess=True, pro
)

self.pproc.load_state_dict(state_dict)
self.to(self.device)

def forward(self, x, fs=None):
if self.preprocess:
x = self._preprocess(x, fs)
x = x.to(self.device)
x = VGG.forward(self, x)
if self.postprocess:
x = self._postprocess(x)
Expand Down

0 comments on commit 4670116

Please sign in to comment.