diff --git a/test/__main__.py b/test/__main__.py index 474e1b47..ba38d4e6 100644 --- a/test/__main__.py +++ b/test/__main__.py @@ -12,7 +12,7 @@ running_out = 0 for file in here.iterdir(): if file.is_file() and file.name.startswith("test"): - cmd = f"pytest {file} " + flags + cmd = f"pytest -n 2 {file} " + flags out = subprocess.run(cmd, shell=True).returncode running_out = max(running_out, out) sys.exit(running_out) diff --git a/test/requirements.txt b/test/requirements.txt index 9de88eb6..723e3188 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -2,5 +2,6 @@ beartype jaxlib optax pytest +pytest-xdist scipy tqdm