-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fd53f71
commit ccbcfae
Showing
1 changed file
with
187 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
import pathlib | ||
import copy | ||
import argparse | ||
from typing import Union, Dict | ||
from turnkeyml.tools import FirstTool | ||
import turnkeyml.common.exceptions as exp | ||
import turnkeyml.common.build as build | ||
import turnkeyml.common.filesystem as fs | ||
from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo | ||
from turnkeyml.state import State, load_state | ||
import turnkeyml.common.printing as printing | ||
from turnkeyml.version import __version__ as turnkey_version | ||
|
||
# Skip policy applied by the load-build tool when --skip-policy is not given;
# it is also the first entry in that argument's list of choices.
skip_policy_default = "attempted" | ||
|
||
|
||
def _decode_version_number(version: str) -> Dict[str, int]: | ||
numbers = [int(x) for x in version.split(".")] | ||
return {"major": numbers[0], "minor": numbers[1], "patch": numbers[0]} | ||
|
||
|
||
class LoadBuild(FirstTool):
    """
    Tool that loads a build from a previous usage of TurnkeyML and passes
    its saved State on to the next tool in the sequence.

    Works best with build State that is complete on disk.

    For example:
        - State that references an ONNX file is a good target, because the ONNX file can
            be loaded from disk.
        - State that references a PyTorch model in memory is a poor target, because
            that PyTorch model will not be available when the State file is loaded
            from disk.

    Expected inputs: None

    Outputs:
        - state has the contents of the state.yaml file of the target build.
    """

    unique_name = "load-build"

    def __init__(self):
        super().__init__(monitor_message="Loading cached build")

    @staticmethod
    def parser(add_help: bool = True) -> argparse.ArgumentParser:
        """
        Build the argument parser for this tool, which exposes the
        --skip-policy option.
        """
        parser = __class__.helpful_parser(
            description="Load build state from the cache",
            add_help=add_help,
        )

        # Adjacent string literals are concatenated verbatim, so each sentence
        # carries its own trailing space to keep the rendered help readable.
        parser.add_argument(
            "--skip-policy",
            choices=[skip_policy_default, "failed", "successful", "none"],
            help="Sets the policy for skipping evaluation attempts "
            f"(defaults to {skip_policy_default}). "
            "`attempted` means to skip any previously-attempted evaluation, "
            "whether it succeeded or failed. "
            "`failed` skips evaluations that have already failed once. "
            "`successful` skips evaluations that have already succeeded. "
            "`none` will attempt all evaluations, regardless of whether "
            "they were previously attempted.",
            required=False,
            default=skip_policy_default,
        )

        return parser

    def run(self, state: State, input: str = "", skip_policy=skip_policy_default):
        """
        Load the build state stored at `input` and prepare it for the tools
        that follow in the current sequence.

        Args:
            state: State of the new sequence; only its sequence_info is
                carried over into the loaded state.
            input: path to the state file of the build to load. The file's
                parent directory is treated as the build directory and that
                directory's parent as the cache directory.
            skip_policy: one of "attempted", "failed", "successful", or "none";
                decides whether a build that load-build has already been run
                on should be skipped.

        Returns:
            The State loaded from `input`, with the new sequence appended to
            its sequence_info, build_status reset to INCOMPLETE, and
            invocation_info / models_found populated for status display.

        Raises:
            exp.CacheError: if `input` is not inside a turnkey build directory.
            exp.SkipBuild: if the installed turnkey has a newer major or minor
                version than the one that produced the build, or if the skip
                policy says this build should be skipped.
        """

        source_build_dir = pathlib.Path(input).parent
        source_build_dir_name = source_build_dir.name
        source_cache_dir = source_build_dir.parent

        # Make sure that the target yaml file is actually the state of a turnkey build
        if not fs.is_build_dir(source_cache_dir, source_build_dir_name):
            raise exp.CacheError(
                f"No build found at path: {input}. "
                "Try running `turnkey cache --list --all` to see the builds in your build cache."
            )

        # Save the new sequence's information so that we can append it to the
        # loaded build's sequence information later
        new_sequence_info = state.sequence_info

        # Load the cached build
        printing.log_info(f"Attempting to load: {input}")
        state = load_state(state_path=input)

        # Save the sequence of the prior build so that we can test against it later
        prior_selected_sequence = list(state.sequence_info.keys())

        if state.build_status != build.FunctionStatus.SUCCESSFUL:
            print(f"Warning: loaded build status is {state.build_status}")

        # Raise an exception if there is a version mismatch between the installed
        # version of turnkey and the version of turnkey used to create the loaded
        # build. Only the major and minor components are compared; the patch
        # component is ignored.
        # NOTE(review): the minor comparison also runs when the installed major
        # version is older than the build's; confirm whether that is intended.
        current_version_decoded = _decode_version_number(turnkey_version)
        state_version_decoded = _decode_version_number(state.turnkey_version)
        out_of_date: Union[str, bool] = False
        if current_version_decoded["major"] > state_version_decoded["major"]:
            out_of_date = "major"
        elif current_version_decoded["minor"] > state_version_decoded["minor"]:
            out_of_date = "minor"

        if out_of_date:
            raise exp.SkipBuild(
                f"Your build {state.build_name} was previously built against "
                f"turnkey version {state.turnkey_version}, "
                f"however you are now using turnkey version {turnkey_version}. "
                "The previous build is "
                f"incompatible with this version of turnkey, as indicated by the {out_of_date} "
                "version number changing. See **docs/versioning.md** for details."
            )

        # Append the sequence of this build to the sequence of the loaded build
        stats = fs.Stats(state.cache_dir, state.build_name)
        combined_selected_sequence = copy.deepcopy(prior_selected_sequence)
        for new_tool, new_tool_args in new_sequence_info.items():
            combined_selected_sequence.append(new_tool)
            state.sequence_info[new_tool] = new_tool_args
        stats.save_stat(fs.Keys.SELECTED_SEQUENCE_OF_TOOLS, combined_selected_sequence)

        # Apply the skip policy. A build whose prior sequence does not contain
        # load-build has never been re-opened by this tool, so there is no
        # condition under which it should be skipped. A policy of "none" (or
        # any unmatched condition) falls through and the build continues.
        if self.__class__.unique_name in prior_selected_sequence:
            if skip_policy == "attempted":
                raise exp.SkipBuild(
                    f"Skipping {state.build_name} because it was previously attempted "
                    f"and the skip policy is set to {skip_policy}"
                )
            if (
                skip_policy == "successful"
                and state.build_status == build.FunctionStatus.SUCCESSFUL
            ):
                raise exp.SkipBuild(
                    f"Skipping {state.build_name} because it was previously successfully attempted "
                    f"and the skip policy is set to {skip_policy}"
                )
            if (
                skip_policy == "failed"
                and state.build_status != build.FunctionStatus.SUCCESSFUL
            ):
                raise exp.SkipBuild(
                    f"Skipping {state.build_name} because it was previously unsuccessfully attempted "
                    f"and the skip policy is set to {skip_policy}"
                )

        # Mark the build status as incomplete now that we have re-opened it
        state.build_status = build.FunctionStatus.INCOMPLETE

        # Create a UniqueInvocationInfo and ModelInfo so that we can display status
        # at the end of the sequence
        state.invocation_info = UniqueInvocationInfo(
            name=input,
            script_name=fs.clean_file_name(input),
            hash=0,
            is_target=True,
            extension="_state.yaml",
            executed=1,
        )
        state.models_found = {
            "state_file": ModelInfo(
                model=input,
                name=input,
                script_name=input,
                file=input,
                unique_invocations={0: state.invocation_info},
                hash=0,
            )
        }
        state.invocation_info.params = state.models_found["state_file"].params

        return state