diff --git a/sotabench.py b/sotabench.py index 67816ff..dac1d76 100644 --- a/sotabench.py +++ b/sotabench.py @@ -55,10 +55,10 @@ def get_img_id(image_name): return image_name.split('/')[-1].replace('.JPEG', '') with torch.no_grad(): - for i, (input, target) in enumerate(test_loader): - input = input.to(device='cuda', non_blocking=True) + for i, (data, target) in enumerate(test_loader): + data = data.to(device='cuda', non_blocking=True) target = target.to(device='cuda', non_blocking=True) - output = model(input) + output = model(data) image_ids = [get_img_id(img[0]) for img in test_loader.dataset.imgs[i*test_loader.batch_size:(i+1)*test_loader.batch_size]] evaluator.add(dict(zip(image_ids, list(output.cpu().numpy())))) if evaluator.cache_exists: