Skip to content

Commit

Permalink
Merge branch '193-update-the-files-api' into 183-add-an-entry-stage-class
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfowers committed Jul 17, 2024
2 parents 6eba5f7 + a83badb commit e0b3098
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 67 deletions.
6 changes: 3 additions & 3 deletions examples/files_api/onnx_opset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Example script that demonstrates how to set a custom ONNX opset for a benchmarking run
Example script that demonstrates how to set a custom ONNX opset for a build run
You can run this script in your turnkey Conda environment with:
python onnx_opset.py --onnx-opset YOUR_OPSET
Expand All @@ -10,7 +10,7 @@

import pathlib
import argparse
from turnkeyml import benchmark_files
from turnkeyml import evaluate_files
from turnkeyml.sequence import Sequence
from turnkeyml.tools.export import ExportPytorchModel
from turnkeyml.tools.discovery import Discover
Expand Down Expand Up @@ -39,7 +39,7 @@ def main():
sequence = Sequence(
tools={Discover(): [], ExportPytorchModel(): ["--opset", args.onnx_opset]}
)
benchmark_files(
evaluate_files(
input_files=[path_to_hello_world_script],
sequence=sequence,
)
Expand Down
13 changes: 2 additions & 11 deletions src/turnkeyml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from turnkeyml.tools import Tool
from turnkeyml.sequence.tool_plugins import SUPPORTED_TOOLS
from turnkeyml.cli.spawn import DEFAULT_TIMEOUT_SECONDS
from turnkeyml.files_api import benchmark_files
from turnkeyml.files_api import evaluate_files
import turnkeyml.common.build as build
import turnkeyml.common.printing as printing
from turnkeyml.tools.management_tools import ManagementTool
Expand Down Expand Up @@ -156,15 +156,6 @@ def main():
default=[],
)

parser.add_argument(
"--rebuild",
choices=build.REBUILD_OPTIONS,
dest="rebuild",
help=f"Sets the cache rebuild policy (defaults to {build.DEFAULT_REBUILD_POLICY})",
required=False,
default=build.DEFAULT_REBUILD_POLICY,
)

slurm_or_processes_group = parser.add_mutually_exclusive_group()

slurm_or_processes_group.add_argument(
Expand Down Expand Up @@ -255,7 +246,7 @@ def main():
if len(evaluation_tools) > 0:
# Run the evaluation tools as a build
sequence = Sequence(tools=tool_instances)
benchmark_files(sequence=sequence, **global_args)
evaluate_files(sequence=sequence, **global_args)
else:
# Run the management tools
for management_tool, argv in tool_instances.items():
Expand Down
24 changes: 11 additions & 13 deletions src/turnkeyml/cli/spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,6 @@ def parse_build_name(line: str, current_value: str) -> Optional[str]:
DEFAULT_TIMEOUT_SECONDS = 3600


class Target(Enum):
SLURM = "slurm"
LOCAL_PROCESS = "local_process"


def slurm_jobs_in_queue(job_name=None) -> List[str]:
"""Return the set of slurm jobs that are currently pending/running"""
user = getpass.getuser()
Expand Down Expand Up @@ -197,13 +192,14 @@ def run_turnkey(
build_name: str,
sequence: Sequence,
file_name: str,
target: Target,
process_isolation: bool,
use_slurm: bool,
cache_dir: str,
lean_cache: bool,
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
working_dir: str = os.getcwd(),
ml_cache_dir: Optional[str] = os.environ.get("SLURM_ML_CACHE"),
max_jobs: int = 50,
**kwargs,
):
"""
Run turnkey on a single input file in a separate process (e.g., Slurm, subprocess).
Expand All @@ -213,6 +209,11 @@ def run_turnkey(
The key must be the snake_case version of the CLI argument (e.g, build_only for --build-only)
"""

if use_slurm and process_isolation:
raise ValueError(
"use_slurm and process_isolation are mutually exclusive, but both are True"
)

type_to_formatter = {
str: value_arg,
int: value_arg,
Expand All @@ -225,7 +226,7 @@ def run_turnkey(

# Add cache_dir to kwargs so that it gets processed
# with the other arguments
kwargs["cache_dir"] = cache_dir
kwargs = {"cache_dir": cache_dir, "lean_cache": lean_cache}

for key, value in kwargs.items():
if value is not None:
Expand All @@ -234,7 +235,7 @@ def run_turnkey(

invocation_args = invocation_args + " " + sequence_arg(sequence)

if target == Target.SLURM:
if use_slurm:
# Change args into the format expected by Slurm
slurm_args = " ".join(shlex.split(invocation_args))

Expand Down Expand Up @@ -276,7 +277,7 @@ def run_turnkey(

print(f"Submitting job {job_name} to Slurm")
subprocess.check_call(slurm_command)
elif target == Target.LOCAL_PROCESS:
else: # process isolation
command = "turnkey " + invocation_args
printing.log_info(f"Starting process with command: {command}")

Expand Down Expand Up @@ -373,6 +374,3 @@ def run_turnkey(
"Stats file found, but unable to perform cleanup due to "
f"exception: {stats_exception}"
)

else:
raise ValueError(f"Unsupported value for target: {target}.")
45 changes: 7 additions & 38 deletions src/turnkeyml/files_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,22 @@ def unpack_txt_inputs(input_files: List[str]) -> List[str]:
return processed_files + [f for f in input_files if not f.endswith(".txt")]


# pylint: disable=unused-argument
def benchmark_files(
def evaluate_files(
input_files: List[str],
use_slurm: bool = False,
process_isolation: bool = False,
lean_cache: bool = False,
cache_dir: str = fs.DEFAULT_CACHE_DIR,
labels: List[str] = None,
rebuild: Optional[str] = None,
timeout: Optional[int] = None,
sequence: Union[Dict, Sequence] = None,
):

# Capture the function arguments so that we can forward them
# to downstream APIs
benchmarking_args = copy.deepcopy(locals())
regular_files = []

# Replace .txt files with the models listed inside them
input_files = unpack_txt_inputs(input_files)

# Iterate through each string in the input_files list
regular_files = []
for input_string in input_files:
if not any(char in input_string for char in "*?[]"):
regular_files.append(input_string)
Expand Down Expand Up @@ -97,19 +91,10 @@ def benchmark_files(
else:
timeout_to_use = spawn.DEFAULT_TIMEOUT_SECONDS

benchmarking_args["timeout"] = timeout_to_use

# Convert regular expressions in input files argument
# into full file paths (e.g., [*.py] -> [a.py, b.py] )
input_files_expanded = fs.expand_inputs(input_files)

# Do not forward arguments to downstream APIs
# that will be decoded in this function body
benchmarking_args.pop("input_files")
benchmarking_args.pop("labels")
benchmarking_args.pop("use_slurm")
benchmarking_args.pop("process_isolation")

# Make sure the cache directory exists
fs.make_cache_dir(cache_dir)

Expand Down Expand Up @@ -178,30 +163,14 @@ def benchmark_files(
continue

if use_slurm or process_isolation:
# Decode args into spawn.Target
if use_slurm and process_isolation:
raise ValueError(
"use_slurm and process_isolation are mutually exclusive, but both are True"
)
elif use_slurm:
process_type = spawn.Target.SLURM
elif process_isolation:
process_type = spawn.Target.LOCAL_PROCESS
else:
raise ValueError(
"This code path requires use_slurm or use_process to be True, "
"but both are False"
)

# We want to pass sequence in explicity
benchmarking_args.pop("sequence")

spawn.run_turnkey(
build_name=build_name,
sequence=sequence,
target=process_type,
file_name=encoded_input,
**benchmarking_args,
use_slurm=use_slurm,
process_isolation=process_isolation,
timeout=timeout_to_use,
lean_cache=lean_cache,
)

else:
Expand Down Expand Up @@ -265,7 +234,7 @@ def benchmark_files(
model=file_path_absolute,
sequence=sequence,
cache_dir=cache_dir,
rebuild=rebuild,
rebuild="always",
lean_cache=lean_cache,
)

Expand Down
2 changes: 0 additions & 2 deletions test/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,6 @@ def test_008_cli_turnkey_args(self):
"turnkey",
"-i",
os.path.join(corpus_dir, test_script),
"--rebuild",
"always",
"--cache-dir",
cache_dir,
"discover",
Expand Down

0 comments on commit e0b3098

Please sign in to comment.