From ccfacce90079dab1b262f8bd32cb89ba3d8b760a Mon Sep 17 00:00:00 2001 From: Christoph Gerum Date: Thu, 11 Jan 2024 10:34:56 +0100 Subject: [PATCH] Fix more tests --- test/test_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_train.py b/test/test_train.py index 11299593..055a25c5 100644 --- a/test/test_train.py +++ b/test/test_train.py @@ -104,7 +104,7 @@ def test_datasets(model, dataset, split): "model", ["timm_resnet50", "timm_efficientnet_lite1", "timm_focalnet_base_srf"] ) def test_2d(model): - command_line = f"hannah-train module=image_classifier dataset=fake2d features=identity trainer.devices=[0] model={model} trainer.fast_dev_run=true scheduler.max_lr=2.5 module.batch_size=2" + command_line = f"hannah-train module=image_classifier dataset=fake2d features=identity trainer.devices=1 model={model} trainer.fast_dev_run=true scheduler.max_lr=2.5 module.batch_size=2" subprocess.run(command_line, shell=True, check=True, cwd=topdir) @@ -113,7 +113,7 @@ def test_2d(model): "model", ["timm_resnet50", "timm_efficientnet_lite1", "timm_resnet18"] ) def test_cifar_2d(model): - command_line = f"hannah-train module=image_classifier dataset=cifar10 features=identity trainer.devices=[0] model={model} trainer.fast_dev_run=true scheduler.max_lr=2.5 module.batch_size=2" + command_line = f"hannah-train module=image_classifier dataset=cifar10 features=identity trainer.devices=1 model={model} trainer.fast_dev_run=true scheduler.max_lr=2.5 module.batch_size=2" subprocess.run(command_line, shell=True, check=True, cwd=topdir)