diff --git a/src/turnkeyml/cli/cli.py b/src/turnkeyml/cli/cli.py index 89b9955..fbb0832 100644 --- a/src/turnkeyml/cli/cli.py +++ b/src/turnkeyml/cli/cli.py @@ -2,7 +2,7 @@ import sys import os from difflib import get_close_matches -from typing import List +from typing import List, Dict, Tuple, Any import turnkeyml.common.filesystem as fs from turnkeyml.sequence import Sequence from turnkeyml.tools import Tool, FirstTool, NiceHelpFormatter @@ -63,25 +63,23 @@ def _check_extension( return file_name -def main(): +def parse_tools( + parser: argparse.ArgumentParser, supported_tools: List[Tool] +) -> Tuple[Dict[str, Any], Dict[Tool, List[str]], List[str]]: + """ + Add the help for parsing tools and their args to an ArgumentParser. - tool_parsers = {tool.unique_name: tool.parser() for tool in SUPPORTED_TOOLS} - tool_classes = {tool.unique_name: tool for tool in SUPPORTED_TOOLS} + Then, perform the task of parsing a full turnkey CLI command including + teasing apart the global arguments and separate tool invocations. + """ - # Define the argument parser - parser = CustomArgumentParser( - description="This utility runs tools in a sequence. " - "To use it, provide a list of tools and " - "their arguments. 
See " - "https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md " - "to learn the exact syntax.\n\nExample: turnkey -i my_model.py discover export-pytorch", - formatter_class=NiceHelpFormatter, - ) + tool_parsers = {tool.unique_name: tool.parser() for tool in supported_tools} + tool_classes = {tool.unique_name: tool for tool in supported_tools} # Sort tools into categories and format for the help menu - first_tool_choices = _tool_list_help(SUPPORTED_TOOLS, FirstTool) - eval_tool_choices = _tool_list_help(SUPPORTED_TOOLS, Tool, exclude=FirstTool) - mgmt_tool_choices = _tool_list_help(SUPPORTED_TOOLS, ManagementTool) + first_tool_choices = _tool_list_help(supported_tools, FirstTool) + eval_tool_choices = _tool_list_help(supported_tools, Tool, exclude=FirstTool) + mgmt_tool_choices = _tool_list_help(supported_tools, ManagementTool) tools_action = parser.add_argument( "tools", @@ -101,67 +99,6 @@ def main(): choices=tool_parsers.keys(), ) - parser.add_argument( - "-i", - "--input-files", - nargs="+", - help="One or more inputs that will be evaluated by the tool sequence " - "(e.g., script (.py), ONNX (.onnx), turnkey build state (state.yaml), " - "input list (.txt) files)", - type=lambda file: _check_extension( - ("py", "onnx", "txt", "yaml"), file, parser.error, tool_classes - ), - ) - - parser.add_argument( - "-d", - "--cache-dir", - help="Build cache directory where results will " - f"be stored (defaults to {fs.DEFAULT_CACHE_DIR})", - required=False, - default=fs.DEFAULT_CACHE_DIR, - ) - - parser.add_argument( - "--lean-cache", - dest="lean_cache", - help="Delete all build artifacts (e.g., .onnx files) when the command completes", - action="store_true", - ) - - parser.add_argument( - "--labels", - dest="labels", - help="Filter the --input-files to only include files that have the provided labels", - nargs="*", - default=[], - ) - - slurm_or_processes_group = parser.add_mutually_exclusive_group() - - slurm_or_processes_group.add_argument( - 
"--use-slurm", - dest="use_slurm", - help="Execute on Slurm instead of using local compute resources", - action="store_true", - ) - - slurm_or_processes_group.add_argument( - "--process-isolation", - dest="process_isolation", - help="Isolate evaluating each input into a separate process", - action="store_true", - ) - - parser.add_argument( - "--timeout", - type=int, - default=None, - help="Build timeout, in seconds, after which a build will be canceled " - f"(default={DEFAULT_TIMEOUT_SECONDS}). Only " - "applies when --process-isolation or --use-slurm is also used.", - ) - # run as if "-h" was passed if no parameters are passed if len(sys.argv) == 1: sys.argv.append("-h") @@ -229,9 +166,91 @@ def main(): # Convert tool names into Tool instances tool_instances = {tool_classes[cmd](): argv for cmd, argv in tools_invoked.items()} + evaluation_tools = [tool_classes[cmd] for cmd in evaluation_tools] + + return global_args, tool_instances, evaluation_tools + + +def main(): + + # Define the argument parser + parser = CustomArgumentParser( + description="This utility runs tools in a sequence. " + "To use it, provide a list of tools and " + "their arguments. 
See " + "https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md " + "to learn the exact syntax.\n\nExample: turnkey -i my_model.py discover export-pytorch", + formatter_class=NiceHelpFormatter, + ) + + parser.add_argument( + "-i", + "--input-files", + nargs="+", + help="One or more inputs that will be evaluated by the tool sequence " + "(e.g., script (.py), ONNX (.onnx), turnkey build state (state.yaml), " + "input list (.txt) files)", + type=lambda file: _check_extension( + ("py", "onnx", "txt", "yaml"), + file, + parser.error, + {tool.unique_name: tool for tool in SUPPORTED_TOOLS}, + ), + ) + + parser.add_argument( + "-d", + "--cache-dir", + help="Build cache directory where results will " + f"be stored (defaults to {fs.DEFAULT_CACHE_DIR})", + required=False, + default=fs.DEFAULT_CACHE_DIR, + ) + + parser.add_argument( + "--lean-cache", + dest="lean_cache", + help="Delete all build artifacts (e.g., .onnx files) when the command completes", + action="store_true", + ) + + parser.add_argument( + "--labels", + dest="labels", + help="Filter the --input-files to only include files that have the provided labels", + nargs="*", + default=[], + ) + + slurm_or_processes_group = parser.add_mutually_exclusive_group() + + slurm_or_processes_group.add_argument( + "--use-slurm", + dest="use_slurm", + help="Execute on Slurm instead of using local compute resources", + action="store_true", + ) + + slurm_or_processes_group.add_argument( + "--process-isolation", + dest="process_isolation", + help="Isolate evaluating each input into a separate process", + action="store_true", + ) + + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Build timeout, in seconds, after which a build will be canceled " + f"(default={DEFAULT_TIMEOUT_SECONDS}). 
def add_to_state(
    state: State,
    name: str,
    model: Union[str, torch.nn.Module],
    extension: str = "",
    input_shapes: Optional[Dict] = None,
) -> None:
    """
    Attach a UniqueInvocationInfo and a single-entry ModelInfo dict to `state`
    so that status can be displayed at the end of the sequence.

    Args:
        state: build state to update in place; `state.invocation_info` and
            `state.models_found` are overwritten.
        name: name of the model. If it is a path that exists on disk, it is
            also recorded as the invocation's file and its cleaned file name
            is used as the script name.
        model: the model itself, or a path to it, stored on the ModelInfo.
        extension: file extension recorded on the invocation info.
        input_shapes: optional input shapes recorded on the invocation info.
    """

    # Fall back to 0 when the state has no model_hash attribute (or a falsy one)
    if vars(state).get("model_hash"):
        model_hash = state.model_hash
    else:
        model_hash = 0

    # Only treat `name` as a file when it actually exists on disk
    if os.path.exists(name):
        file_name = fs.clean_file_name(name)
        file = name
    else:
        file_name = name
        file = ""

    # NOTE: these fields must use the `name` argument. This function has no
    # `input` parameter, so a bare `input` would resolve to the Python builtin
    # function rather than to a string.
    state.invocation_info = UniqueInvocationInfo(
        name=name,
        script_name=file_name,
        file=file,
        input_shapes=input_shapes,
        hash=model_hash,
        is_target=True,
        extension=extension,
        executed=1,
    )
    state.models_found = {
        "the_model": ModelInfo(
            model=model,
            name=name,
            script_name=name,
            file=name,
            unique_invocations={model_hash: state.invocation_info},
            hash=model_hash,
        )
    }
    # Share the params object between the ModelInfo and the invocation info
    state.invocation_info.params = state.models_found["the_model"].params
b/src/turnkeyml/tools/load_build.py index 0f760c5..09c1a93 100644 --- a/src/turnkeyml/tools/load_build.py +++ b/src/turnkeyml/tools/load_build.py @@ -6,7 +6,7 @@ import turnkeyml.common.exceptions as exp import turnkeyml.common.build as build import turnkeyml.common.filesystem as fs -from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo +import turnkeyml.common.status as status from turnkeyml.state import State, load_state import turnkeyml.common.printing as printing from turnkeyml.version import __version__ as turnkey_version @@ -186,24 +186,6 @@ def run(self, state: State, input: str = "", skip_policy=skip_policy_default): # Create a UniqueInvocationInfo and ModelInfo so that we can display status # at the end of the sequence - state.invocation_info = UniqueInvocationInfo( - name=input, - script_name=fs.clean_file_name(input), - hash=0, - is_target=True, - extension="_state.yaml", - executed=1, - ) - state.models_found = { - "state_file": ModelInfo( - model=input, - name=input, - script_name=input, - file=input, - unique_invocations={0: state.invocation_info}, - hash=0, - ) - } - state.invocation_info.params = state.models_found["state_file"].params + status.add_to_state(state=state, name=input, model=input) return state diff --git a/src/turnkeyml/tools/onnx.py b/src/turnkeyml/tools/onnx.py index 9e2d2e7..a7296c9 100644 --- a/src/turnkeyml/tools/onnx.py +++ b/src/turnkeyml/tools/onnx.py @@ -13,7 +13,7 @@ import turnkeyml.common.tensor_helpers as tensor_helpers import turnkeyml.common.onnx_helpers as onnx_helpers import turnkeyml.common.filesystem as fs -from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo +import turnkeyml.common.status as status from turnkeyml.state import State @@ -154,27 +154,13 @@ def run(self, state: State, input: str = ""): # Create a UniqueInvocationInfo and ModelInfo so that we can display status # at the end of the sequence - state.invocation_info = UniqueInvocationInfo( + status.add_to_state( + state=state, 
name=onnx_file, - script_name=fs.clean_file_name(onnx_file), - file=onnx_file, - input_shapes={key: value.shape for key, value in state.inputs.items()}, - hash=state.model_hash, - is_target=True, + model=onnx_file, extension=".onnx", - executed=1, + input_shapes={key: value.shape for key, value in state.inputs.items()}, ) - state.models_found = { - "onnx_file": ModelInfo( - model=onnx_file, - name=onnx_file, - script_name=onnx_file, - file=onnx_file, - unique_invocations={state.model_hash: state.invocation_info}, - hash=state.model_hash, - ) - } - state.invocation_info.params = state.models_found["onnx_file"].params return state diff --git a/src/turnkeyml/tools/tool.py b/src/turnkeyml/tools/tool.py index 273f88f..de77d31 100644 --- a/src/turnkeyml/tools/tool.py +++ b/src/turnkeyml/tools/tool.py @@ -161,6 +161,7 @@ def status_line(self, successful, verbosity): def __init__( self, monitor_message, + enable_logger=True, ): _name_is_file_safe(self.__class__.unique_name) @@ -169,6 +170,9 @@ def __init__( self.monitor_message = monitor_message self.progress = None self.logfile_path = None + # Tools can disable build.Logger, which captures all stdout and stderr from + # the Tool, by setting enable_logger=False + self.enable_logger = enable_logger # Tools can provide a list of keys that can be found in # evaluation stats. Those key:value pairs will be presented # in the status at the end of the build. @@ -251,7 +255,11 @@ def run_helper( try: # Execute the build tool - with build.Logger(self.monitor_message, self.logfile_path): + + if self.enable_logger: + with build.Logger(self.monitor_message, self.logfile_path): + state = self.run(state, **kwargs) + else: state = self.run(state, **kwargs) except Exception: # pylint: disable=broad-except diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py index 131942e..8d1c862 100644 --- a/src/turnkeyml/version.py +++ b/src/turnkeyml/version.py @@ -1 +1 @@ -__version__ = "3.0.2" +__version__ = "3.0.3"