Skip to content

Commit

Permalink
add the file
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremyfowers committed Jul 17, 2024
1 parent fd53f71 commit ccbcfae
Showing 1 changed file with 187 additions and 0 deletions.
187 changes: 187 additions & 0 deletions src/turnkeyml/tools/load_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
import pathlib
import copy
import argparse
from typing import Union, Dict
from turnkeyml.tools import FirstTool
import turnkeyml.common.exceptions as exp
import turnkeyml.common.build as build
import turnkeyml.common.filesystem as fs
from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo
from turnkeyml.state import State, load_state
import turnkeyml.common.printing as printing
from turnkeyml.version import __version__ as turnkey_version

skip_policy_default = "attempted"


def _decode_version_number(version: str) -> Dict[str, int]:
numbers = [int(x) for x in version.split(".")]
return {"major": numbers[0], "minor": numbers[1], "patch": numbers[0]}


class LoadBuild(FirstTool):
"""
Tool that loads a build from a previous usage of TurnkeyML and passes
its saved State on to the next tool in the sequence.
Works best with build State that is complete on disk.
For example:
- State that references an ONNX file is a good target, because the ONNX file can
be loaded from disk.
- State that references a PyTorch model in memory is a poor target, because
that PyTorch model will not be available when the State file is loaded
from disk.
Expected inputs: None
Outputs:
- state has the contents of the state.yaml file of the target build.
"""

unique_name = "load-build"

def __init__(self):
super().__init__(monitor_message="Loading cached build")

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
parser = __class__.helpful_parser(
description="Load build state from the cache",
add_help=add_help,
)

parser.add_argument(
"--skip-policy",
choices=[skip_policy_default, "failed", "successful", "none"],
help="Sets the policy for skipping evaluation attempts "
f"(defaults to {skip_policy_default})."
"`attempted` means to skip any previously-attempted evaluation, "
"whether it succeeded or failed."
"`failed` skips evaluations that have already failed once."
"`successful` skips evaluations that have already succeeded."
"`none` will attempt all evaluations, regardless of whether "
"they were previously attempted.",
required=False,
default=skip_policy_default,
)

return parser

def run(self, state: State, input: str = "", skip_policy=skip_policy_default):

source_build_dir = pathlib.Path(input).parent
source_build_dir_name = source_build_dir.name
source_cache_dir = source_build_dir.parent

# Make sure that the target yaml file is actually the state of a turnkey build
if not fs.is_build_dir(source_cache_dir, source_build_dir_name):
raise exp.CacheError(
f"No build found at path: {input}. "
"Try running `turnkey cache --list --all` to see the builds in your build cache."
)

# Save the new sequence's information so that we can append it to the
# loaded build's sequence information later
new_sequence_info = state.sequence_info

# Load the cached build
printing.log_info(f"Attempting to load: {input}")
state = load_state(state_path=input)
# Save the sequence of the prior build so that we can test against it later
prior_selected_sequence = list(state.sequence_info.keys())

if state.build_status != build.FunctionStatus.SUCCESSFUL:
print(f"Warning: loaded build status is {state.build_status}")

# Raise an exception if there is a version mismatch between the installed
# version of turnkey and the version of turnkey used to create the loaded
# build
current_version_decoded = _decode_version_number(turnkey_version)
state_version_decoded = _decode_version_number(state.turnkey_version)
out_of_date: Union[str, bool] = False
if current_version_decoded["major"] > state_version_decoded["major"]:
out_of_date = "major"
elif current_version_decoded["minor"] > state_version_decoded["minor"]:
out_of_date = "minor"

if out_of_date:
raise exp.SkipBuild(
f"Your build {state.build_name} was previously built against "
f"turnkey version {state.turnkey_version}, "
f"however you are now using turnkey version {turnkey_version}. "
"The previous build is "
f"incompatible with this version of turnkey, as indicated by the {out_of_date} "
"version number changing. See **docs/versioning.md** for details."
)

# Append the sequence of this build to the sequence of the loaded build
stats = fs.Stats(state.cache_dir, state.build_name)
combined_selected_sequence = copy.deepcopy(prior_selected_sequence)
for new_tool, new_tool_args in new_sequence_info.items():
combined_selected_sequence.append(new_tool)
state.sequence_info[new_tool] = new_tool_args
stats.save_stat(fs.Keys.SELECTED_SEQUENCE_OF_TOOLS, combined_selected_sequence)

# Apply the skip policy by skipping over this iteration of the
# loop if the evaluation's pre-existing build status doesn't
# meet certain criteria
if self.__class__.unique_name not in prior_selected_sequence:
# This build has not been attempted by load_build yet, so there
# is no condition under which it should be skipped
pass
else:
if skip_policy == "attempted":
raise exp.SkipBuild(
f"Skipping {state.build_name} because it was previously attempted "
f"and the skip policy is set to {skip_policy}"
)
elif (
skip_policy == "successful"
and state.build_status == build.FunctionStatus.SUCCESSFUL
):
raise exp.SkipBuild(
f"Skipping {state.build_name} because it was previously successfully attempted "
f"and the skip policy is set to {skip_policy}"
)
elif (
skip_policy == "failed"
and state.build_status != build.FunctionStatus.SUCCESSFUL
):
raise exp.SkipBuild(
f"Skipping {state.build_name} because it was previously unsuccessfully attempted "
f"and the skip policy is set to {skip_policy}"
)
elif skip_policy == "none":
# Skip policy of "none" means we should never skip over a build
pass
else:
# The skip condition is not met, so we will continue
pass

# Mark the build status as incomplete now that we have re-opened it
state.build_status = build.FunctionStatus.INCOMPLETE

# Create a UniqueInvocationInfo and ModelInfo so that we can display status
# at the end of the sequence
state.invocation_info = UniqueInvocationInfo(
name=input,
script_name=fs.clean_file_name(input),
hash=0,
is_target=True,
extension="_state.yaml",
executed=1,
)
state.models_found = {
"state_file": ModelInfo(
model=input,
name=input,
script_name=input,
file=input,
unique_invocations={0: state.invocation_info},
hash=0,
)
}
state.invocation_info.params = state.models_found["state_file"].params

return state

0 comments on commit ccbcfae

Please sign in to comment.