From 0304510203614a3b2d75fb44d7d64d1664c977b1 Mon Sep 17 00:00:00 2001 From: linnil1 Date: Fri, 13 Aug 2021 15:57:14 +0800 Subject: [PATCH] Fix cuda arguments for non-gpu server --- ezgeno/utils.py | 4 ++-- ezgeno/visualize.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ezgeno/utils.py b/ezgeno/utils.py index db56e26..82ad54d 100644 --- a/ezgeno/utils.py +++ b/ezgeno/utils.py @@ -6,12 +6,12 @@ from torch.autograd import Variable from collections import defaultdict -def get_variable(inputs, cuda=False, **kwargs): +def get_variable(inputs, cuda=-1, **kwargs): if type(inputs) in [list, np.ndarray]: inputs = torch.Tensor(inputs) if cuda==-1: - out = Variable(inputs.cuda(), **kwargs) + out = Variable(inputs, **kwargs) else: out = Variable(inputs.to('cuda:%d'%cuda), **kwargs) return out diff --git a/ezgeno/visualize.py b/ezgeno/visualize.py index d3e00c2..9bf8e24 100644 --- a/ezgeno/visualize.py +++ b/ezgeno/visualize.py @@ -215,7 +215,7 @@ def show_grad_cam(args, model_path, data_name, model, use_cuda, window=9): #test_data = testset(args.dataSource,testLabelPath,testFileList) seq_length = test_data.seq_length test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, num_workers=8) - grad_cam = GradCam(model=model, target_layer_names=args.target_layer_names, seq_length=seq_length, use_cuda=True) + grad_cam = GradCam(model=model, target_layer_names=args.target_layer_names, seq_length=seq_length, use_cuda=use_cuda) target_index = None pred_list = []