Skip to content

Commit

Permalink
Improve TurnkeyML Framework Extensibility (#209)
Browse files Browse the repository at this point in the history
* Allow disabling the build logger from inside a Tool

* Allow creating custom CLIs

* reform status

* rev version number
  • Loading branch information
jeremyfowers authored Aug 2, 2024
1 parent a3695e9 commit abbf3cb
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 119 deletions.
175 changes: 97 additions & 78 deletions src/turnkeyml/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import os
from difflib import get_close_matches
from typing import List
from typing import List, Dict, Tuple, Any
import turnkeyml.common.filesystem as fs
from turnkeyml.sequence import Sequence
from turnkeyml.tools import Tool, FirstTool, NiceHelpFormatter
Expand Down Expand Up @@ -63,25 +63,23 @@ def _check_extension(
return file_name


def main():
def parse_tools(
parser: argparse.ArgumentParser, supported_tools: List[Tool]
) -> Tuple[Dict[str, Any], Dict[Tool, List[str]], List[str]]:
"""
Add the help for parsing tools and their args to an ArgumentParser.
tool_parsers = {tool.unique_name: tool.parser() for tool in SUPPORTED_TOOLS}
tool_classes = {tool.unique_name: tool for tool in SUPPORTED_TOOLS}
Then, perform the task of parsing a full turnkey CLI command including
teasing apart the global arguments and separate tool invocations.
"""

# Define the argument parser
parser = CustomArgumentParser(
description="This utility runs tools in a sequence. "
"To use it, provide a list of tools and "
"their arguments. See "
"https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md "
"to learn the exact syntax.\n\nExample: turnkey -i my_model.py discover export-pytorch",
formatter_class=NiceHelpFormatter,
)
tool_parsers = {tool.unique_name: tool.parser() for tool in supported_tools}
tool_classes = {tool.unique_name: tool for tool in supported_tools}

# Sort tools into categories and format for the help menu
first_tool_choices = _tool_list_help(SUPPORTED_TOOLS, FirstTool)
eval_tool_choices = _tool_list_help(SUPPORTED_TOOLS, Tool, exclude=FirstTool)
mgmt_tool_choices = _tool_list_help(SUPPORTED_TOOLS, ManagementTool)
first_tool_choices = _tool_list_help(supported_tools, FirstTool)
eval_tool_choices = _tool_list_help(supported_tools, Tool, exclude=FirstTool)
mgmt_tool_choices = _tool_list_help(supported_tools, ManagementTool)

tools_action = parser.add_argument(
"tools",
Expand All @@ -101,67 +99,6 @@ def main():
choices=tool_parsers.keys(),
)

parser.add_argument(
"-i",
"--input-files",
nargs="+",
help="One or more inputs that will be evaluated by the tool sequence "
"(e.g., script (.py), ONNX (.onnx), turnkey build state (state.yaml), "
"input list (.txt) files)",
type=lambda file: _check_extension(
("py", "onnx", "txt", "yaml"), file, parser.error, tool_classes
),
)

parser.add_argument(
"-d",
"--cache-dir",
help="Build cache directory where results will "
f"be stored (defaults to {fs.DEFAULT_CACHE_DIR})",
required=False,
default=fs.DEFAULT_CACHE_DIR,
)

parser.add_argument(
"--lean-cache",
dest="lean_cache",
help="Delete all build artifacts (e.g., .onnx files) when the command completes",
action="store_true",
)

parser.add_argument(
"--labels",
dest="labels",
help="Filter the --input-files to only include files that have the provided labels",
nargs="*",
default=[],
)

slurm_or_processes_group = parser.add_mutually_exclusive_group()

slurm_or_processes_group.add_argument(
"--use-slurm",
dest="use_slurm",
help="Execute on Slurm instead of using local compute resources",
action="store_true",
)

slurm_or_processes_group.add_argument(
"--process-isolation",
dest="process_isolation",
help="Isolate evaluating each input into a separate process",
action="store_true",
)

parser.add_argument(
"--timeout",
type=int,
default=None,
help="Build timeout, in seconds, after which a build will be canceled "
f"(default={DEFAULT_TIMEOUT_SECONDS}). Only "
"applies when --process-isolation or --use-slurm is also used.",
)

# run as if "-h" was passed if no parameters are passed
if len(sys.argv) == 1:
sys.argv.append("-h")
Expand Down Expand Up @@ -229,9 +166,91 @@ def main():

# Convert tool names into Tool instances
tool_instances = {tool_classes[cmd](): argv for cmd, argv in tools_invoked.items()}
evaluation_tools = [tool_classes[cmd] for cmd in evaluation_tools]

return global_args, tool_instances, evaluation_tools


def main():

# Define the argument parser
parser = CustomArgumentParser(
description="This utility runs tools in a sequence. "
"To use it, provide a list of tools and "
"their arguments. See "
"https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md "
"to learn the exact syntax.\n\nExample: turnkey -i my_model.py discover export-pytorch",
formatter_class=NiceHelpFormatter,
)

parser.add_argument(
"-i",
"--input-files",
nargs="+",
help="One or more inputs that will be evaluated by the tool sequence "
"(e.g., script (.py), ONNX (.onnx), turnkey build state (state.yaml), "
"input list (.txt) files)",
type=lambda file: _check_extension(
("py", "onnx", "txt", "yaml"),
file,
parser.error,
{tool.unique_name: tool for tool in SUPPORTED_TOOLS},
),
)

parser.add_argument(
"-d",
"--cache-dir",
help="Build cache directory where results will "
f"be stored (defaults to {fs.DEFAULT_CACHE_DIR})",
required=False,
default=fs.DEFAULT_CACHE_DIR,
)

parser.add_argument(
"--lean-cache",
dest="lean_cache",
help="Delete all build artifacts (e.g., .onnx files) when the command completes",
action="store_true",
)

parser.add_argument(
"--labels",
dest="labels",
help="Filter the --input-files to only include files that have the provided labels",
nargs="*",
default=[],
)

slurm_or_processes_group = parser.add_mutually_exclusive_group()

slurm_or_processes_group.add_argument(
"--use-slurm",
dest="use_slurm",
help="Execute on Slurm instead of using local compute resources",
action="store_true",
)

slurm_or_processes_group.add_argument(
"--process-isolation",
dest="process_isolation",
help="Isolate evaluating each input into a separate process",
action="store_true",
)

parser.add_argument(
"--timeout",
type=int,
default=None,
help="Build timeout, in seconds, after which a build will be canceled "
f"(default={DEFAULT_TIMEOUT_SECONDS}). Only "
"applies when --process-isolation or --use-slurm is also used.",
)

global_args, tool_instances, evaluation_tools = parse_tools(parser, SUPPORTED_TOOLS)

if len(evaluation_tools) > 0:
if not issubclass(tool_classes[evaluation_tools[0]], FirstTool):
if not issubclass(evaluation_tools[0], FirstTool):
parser.error(
"The first tool in the sequence needs to be one "
"of the 'tools that can start a sequence.' Use "
Expand Down
43 changes: 43 additions & 0 deletions src/turnkeyml/common/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, List, Union, Dict, Optional
import torch
from turnkeyml.common import printing
from turnkeyml.state import State
import turnkeyml.common.build as build
import turnkeyml.common.filesystem as fs
import turnkeyml.common.analyze_model as analyze_model
Expand Down Expand Up @@ -362,3 +363,45 @@ def stop_logger_forward() -> None:
sys.stdout = sys.stdout.terminal
if hasattr(sys.stderr, "terminal_err"):
sys.stderr = sys.stderr.terminal_err


def add_to_state(
state: State,
name: str,
model: Union[str, torch.nn.Module],
extension: str = "",
input_shapes: Optional[Dict] = None,
):
if vars(state).get("model_hash"):
model_hash = state.model_hash
else:
model_hash = 0

if os.path.exists(name):
file_name = fs.clean_file_name(name)
file = name
else:
file_name = name
file = ""

state.invocation_info = UniqueInvocationInfo(
name=input,
script_name=file_name,
file=file,
input_shapes=input_shapes,
hash=model_hash,
is_target=True,
extension=extension,
executed=1,
)
state.models_found = {
"the_model": ModelInfo(
model=model,
name=input,
script_name=input,
file=input,
unique_invocations={model_hash: state.invocation_info},
hash=model_hash,
)
}
state.invocation_info.params = state.models_found["the_model"].params
22 changes: 2 additions & 20 deletions src/turnkeyml/tools/load_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import turnkeyml.common.exceptions as exp
import turnkeyml.common.build as build
import turnkeyml.common.filesystem as fs
from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo
import turnkeyml.common.status as status
from turnkeyml.state import State, load_state
import turnkeyml.common.printing as printing
from turnkeyml.version import __version__ as turnkey_version
Expand Down Expand Up @@ -186,24 +186,6 @@ def run(self, state: State, input: str = "", skip_policy=skip_policy_default):

# Create a UniqueInvocationInfo and ModelInfo so that we can display status
# at the end of the sequence
state.invocation_info = UniqueInvocationInfo(
name=input,
script_name=fs.clean_file_name(input),
hash=0,
is_target=True,
extension="_state.yaml",
executed=1,
)
state.models_found = {
"state_file": ModelInfo(
model=input,
name=input,
script_name=input,
file=input,
unique_invocations={0: state.invocation_info},
hash=0,
)
}
state.invocation_info.params = state.models_found["state_file"].params
status.add_to_state(state=state, name=input, model=input)

return state
24 changes: 5 additions & 19 deletions src/turnkeyml/tools/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import turnkeyml.common.tensor_helpers as tensor_helpers
import turnkeyml.common.onnx_helpers as onnx_helpers
import turnkeyml.common.filesystem as fs
from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo
import turnkeyml.common.status as status
from turnkeyml.state import State


Expand Down Expand Up @@ -154,27 +154,13 @@ def run(self, state: State, input: str = ""):

# Create a UniqueInvocationInfo and ModelInfo so that we can display status
# at the end of the sequence
state.invocation_info = UniqueInvocationInfo(
status.add_to_state(
state=state,
name=onnx_file,
script_name=fs.clean_file_name(onnx_file),
file=onnx_file,
input_shapes={key: value.shape for key, value in state.inputs.items()},
hash=state.model_hash,
is_target=True,
model=onnx_file,
extension=".onnx",
executed=1,
input_shapes={key: value.shape for key, value in state.inputs.items()},
)
state.models_found = {
"onnx_file": ModelInfo(
model=onnx_file,
name=onnx_file,
script_name=onnx_file,
file=onnx_file,
unique_invocations={state.model_hash: state.invocation_info},
hash=state.model_hash,
)
}
state.invocation_info.params = state.models_found["onnx_file"].params

return state

Expand Down
10 changes: 9 additions & 1 deletion src/turnkeyml/tools/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def status_line(self, successful, verbosity):
def __init__(
self,
monitor_message,
enable_logger=True,
):
_name_is_file_safe(self.__class__.unique_name)

Expand All @@ -169,6 +170,9 @@ def __init__(
self.monitor_message = monitor_message
self.progress = None
self.logfile_path = None
# Tools can disable build.Logger, which captures all stdout and stderr from
# the Tool, by setting enable_logger=False
self.enable_logger = enable_logger
# Tools can provide a list of keys that can be found in
# evaluation stats. Those key:value pairs will be presented
# in the status at the end of the build.
Expand Down Expand Up @@ -251,7 +255,11 @@ def run_helper(

try:
# Execute the build tool
with build.Logger(self.monitor_message, self.logfile_path):

if self.enable_logger:
with build.Logger(self.monitor_message, self.logfile_path):
state = self.run(state, **kwargs)
else:
state = self.run(state, **kwargs)

except Exception: # pylint: disable=broad-except
Expand Down
2 changes: 1 addition & 1 deletion src/turnkeyml/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.0.2"
__version__ = "3.0.3"

0 comments on commit abbf3cb

Please sign in to comment.