Skip to content

Commit

Permalink
Update sleepnet.py
Browse files Browse the repository at this point in the history
Fixed for implementation of pytorch use of cuda device
  • Loading branch information
LeandroCasiraghi authored Aug 2, 2023
1 parent ac8d94b commit 46ecaf1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/asleep/sleepnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ def setup_dataset(X, pid, cfg, is_train=False):


def config_device(cfg):
if cfg.gpu != -1:
my_device = "cuda:" + str(cfg.gpu)
if cfg.gpu != 'cpu':
my_device = str(cfg.gpu)
print ("pytorch device: "+my_device)
else:
my_device = "cpu"
print ("pytorch device defaulting to 'cpu'")
return my_device


Expand Down Expand Up @@ -198,7 +200,7 @@ def sleepnet_inference(X, pid, weight_path, cfg, local_repo_path=""):
return aligned_y_pred, test_pid


def start_sleep_net(X, pid, data_root, weight_path, device_id=-1, local_repo_path=""):
def start_sleep_net(X, pid, data_root, weight_path, device_id='cpu', local_repo_path=""):
initialize(config_path="conf")
cfg = compose(
"config_eval",
Expand Down

0 comments on commit 46ecaf1

Please sign in to comment.