Commit fd53f71: progress
jeremyfowers committed Jul 17, 2024
1 parent e0b3098

Showing 9 changed files with 143 additions and 152 deletions.
3 changes: 1 addition & 2 deletions src/turnkeyml/__init__.py
@@ -1,6 +1,5 @@
from turnkeyml.version import __version__

from .files_api import benchmark_files
from .files_api import evaluate_files
from .cli.cli import main as turnkeycli
from .sequence.build_api import build_model
from .state import load_state, State
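
The package's public entry point is renamed from benchmark_files to evaluate_files. A minimal, hypothetical call sketch for the renamed API follows (the file path and cache directory are illustrative; a sequence argument describing the build tools would normally be supplied as well, but its construction is not shown in this commit):

import turnkeyml as tk

# Hypothetical invocation of the renamed entry point. Paths are illustrative;
# per the docstring added to evaluate_files() below, cache_dir defaults to the
# current working directory and an absolute path is recommended.
# A sequence argument (the build tools to run) would normally be passed too.
tk.evaluate_files(
    input_files=["models/my_model.py"],
    cache_dir="/home/me/turnkey_cache",
    lean_cache=False,
)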
2 changes: 2 additions & 0 deletions src/turnkeyml/common/analyze_model.py
@@ -27,6 +27,8 @@ def count_parameters(model: torch.nn.Module) -> int:
if tensor.name not in onnx_model.graph.input
)
)
elif isinstance(model, str) and model.endswith(".yaml"):
return None

# Raise exception if an unsupported model type is provided
raise AnalysisException(f"model type {type(model)} is not supported")
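
A short sketch of how the new return value is meant to be consumed, assuming cached builds are referenced by a *_state.yaml path (the file name below is illustrative): count_parameters() now reports None for them, and the status-printing changes further down guard on exactly that.

from turnkeyml.common.analyze_model import count_parameters

# Illustrative only: a cached build has no in-memory model to inspect, so the
# parameter count comes back as None and callers skip parameter-dependent output.
params = count_parameters("my_build_state.yaml")  # returns None under the new branch
if params is not None:
    print(f"Parameters: {params:,}")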
11 changes: 6 additions & 5 deletions src/turnkeyml/common/filesystem.py
@@ -98,7 +98,8 @@ def clean_file_name(script_path: str) -> str:
If it's a state.yaml file, trim the "state.yaml"
"""
if script_path.endswith("_state.yaml"):
return script_path.replace("_state.yaml", "")
return pathlib.Path(script_path).stem.replace("_state", "")
# return script_path.replace("_state.yaml", "")
else:
return pathlib.Path(script_path).stem
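
For reference, a quick sketch of what this change does for a cached build's state file (the path is illustrative): the old expression stripped only the suffix and kept the directory prefix, while the new one reduces the path to the bare build name, matching how plain scripts are handled.

import pathlib

path = "/cache/my_build/my_build_state.yaml"   # illustrative path

# Old behavior: suffix stripped, directory prefix kept.
old = path.replace("_state.yaml", "")                # "/cache/my_build/my_build"

# New behavior: just the build name.
new = pathlib.Path(path).stem.replace("_state", "")  # "my_build"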

@@ -117,7 +118,7 @@ def _load_yaml(file) -> Dict:
return {}


def _save_yaml(dict: Dict, file):
def save_yaml(dict: Dict, file):
with open(file, "w", encoding="utf8") as outfile:
yaml.dump(dict, outfile)

@@ -395,7 +396,7 @@ def __init__(self, cache_dir: str, build_name: str):
os.makedirs(os.path.dirname(self.file), exist_ok=True)
if not os.path.exists(self.file):
# Start an empty stats file
_save_yaml({}, self.file)
save_yaml({}, self.file)

@property
def stats(self):
@@ -427,14 +428,14 @@ def save_stat(self, key: str, value):

self._set_key(stats_dict, [key], value)

_save_yaml(stats_dict, self.file)
save_yaml(stats_dict, self.file)

def save_sub_stat(self, parent_key: str, key: str, value):
stats_dict = self.stats

self._set_key(stats_dict, [parent_key, key], value)

_save_yaml(stats_dict, self.file)
save_yaml(stats_dict, self.file)

def save_eval_error_log(self, logfile_path):
if logfile_path is None:
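
The YAML writer is also promoted from a private helper (_save_yaml) to a public save_yaml, presumably so it can be reused outside this module. A minimal usage sketch, with an illustrative destination path:

import turnkeyml.common.filesystem as fs

# Hypothetical call to the now-public helper: dump a small stats dict to disk.
fs.save_yaml({"author": "example", "task": "classification"}, "/tmp/example_stats.yaml")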
10 changes: 5 additions & 5 deletions src/turnkeyml/common/status.py
@@ -76,7 +76,7 @@ class UniqueInvocationInfo(BasicInfo):
executed: int = 0
exec_time: float = 0.0
status_message: str = ""
extra_status: Optional[str] = None
extra_status: Optional[str] = ""
is_target: bool = False
auto_selected: bool = False
status_message_color: printing.Colors = printing.Colors.ENDC
@@ -106,7 +106,7 @@ def _print_heading(
print(f"{self.script_name}{self.extension}:")

# Print invocation about the model (only applies to scripts, not ONNX files)
if not self.extension == ".onnx":
if not (self.extension == ".onnx" or self.extension == "_state.yaml"):
if self.depth == 0 and multiple_unique_invocations:
if not model_visited:
printing.logn(f"{self.indent}{self.name}")
@@ -121,7 +121,7 @@ def _print_location(self):
self.skip.model_name = True

def _print_location(self):
if self.skip.location:
if self.skip.location or self.file == "":
return

if self.depth == 0:
@@ -133,7 +133,7 @@ def _print_parameters(self):
self.skip.location = True

def _print_parameters(self):
if self.skip.parameters:
if self.skip.parameters or self.params is None:
return

# Display number of parameters and size
@@ -163,7 +163,7 @@ def _print_unique_input_shape(
self.skip.unique_input_shape = True

def _print_input_shape(self):
if self.skip.input_shape:
if self.skip.input_shape or self.input_shapes is None:
return

# Prepare input shape to be printed
68 changes: 31 additions & 37 deletions src/turnkeyml/files_api.py
@@ -1,8 +1,6 @@
import time
import os
import copy
import glob
from datetime import datetime
from typing import List, Dict, Optional, Union
import git
import turnkeyml.common.printing as printing
@@ -11,8 +9,7 @@
import turnkeyml.cli.spawn as spawn
import turnkeyml.common.filesystem as fs
import turnkeyml.common.labels as labels_library
import turnkeyml.common.build as build
from turnkeyml.sequence.build_api import build_model
from turnkeyml.state import State

# The licensing for tqdm is confusing. Pending a legal scan,
# the following code provides tqdm to users who have installed
@@ -58,6 +55,18 @@ def evaluate_files(
timeout: Optional[int] = None,
sequence: Union[Dict, Sequence] = None,
):
"""
Args:
sequence: the build tools and their arguments used to build the model.
build_name: Unique name for the model that will be
used to store the ONNX file and build state on disk. Defaults to the
name of the file that calls build_model().
cache_dir: Directory to use as the cache for this build. Output files
from this build will be stored at cache_dir/build_name/
Defaults to the current working directory, but we recommend setting it to
an absolute path of your choosing.
lean_cache: delete build artifacts after the build has completed.
"""

# Replace .txt files with the models listed inside them
input_files = unpack_txt_inputs(input_files)
@@ -152,8 +161,10 @@ def evaluate_files(

# Skip a file if the required_labels are not a subset of the script_labels.
if labels:
# Labels argument is not supported for ONNX files
if file_path_absolute.endswith(".onnx"):
# Labels argument is not supported for ONNX files or cached builds
if file_path_absolute.endswith(".onnx") or file_path_absolute.endswith(
".yaml"
):
raise ValueError(
"The labels argument is not supported for .onnx files, got",
file_path_absolute,
@@ -179,31 +190,18 @@
first_tool_args.append("--input")
first_tool_args.append(file_path_encoded)

# Create a build directory and stats file in the cache
fs.make_build_dir(cache_dir, build_name)
stats = fs.Stats(cache_dir, build_name)

# Save the system information used for this build
system_info = build.get_system_info()
stats.save_stat(
fs.Keys.SYSTEM_INFO,
system_info,
)
# Collection of statistics that the sequence instance should save
# to the stats file
stats_to_save = {}

# Save labels info
if fs.Keys.AUTHOR in file_labels:
stats.save_stat(fs.Keys.AUTHOR, file_labels[fs.Keys.AUTHOR][0])
stats_to_save[fs.Keys.AUTHOR] = file_labels[fs.Keys.AUTHOR][0]
if fs.Keys.TASK in file_labels:
stats.save_stat(fs.Keys.TASK, file_labels[fs.Keys.TASK][0])
stats_to_save[fs.Keys.TASK] = file_labels[fs.Keys.TASK][0]

# Save all of the labels in one place
stats.save_stat(fs.Keys.LABELS, file_labels)

# Save a timestamp so that we know the order of builds within a cache
stats.save_stat(
fs.Keys.TIMESTAMP,
datetime.now(),
)
stats_to_save[fs.Keys.LABELS] = file_labels

# If the input script is a built-in TurnkeyML model, make a note of
# which one
@@ -221,21 +219,17 @@
fs.MODELS_DIR,
f"https://github.com/onnx/turnkeyml/tree/{git_hash}/models",
).replace("\\", "/")
stats.save_stat(fs.Keys.MODEL_SCRIPT, relative_path)
stats_to_save[fs.Keys.MODEL_SCRIPT] = relative_path

# Indicate that the build is running. If the build fails for any reason,
# we will try to catch the exception and note it in the stats.
# If a concluded build still has a status of "running", this means
# there was an uncaught exception.
stats.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE)

build_model(
build_name=build_name,
model=file_path_absolute,
sequence=sequence,
state = State(
cache_dir=cache_dir,
rebuild="always",
build_name=build_name,
sequence_info=sequence.info,
)
sequence.launch(
state,
lean_cache=lean_cache,
stats_to_save=stats_to_save,
)

# Wait until all the Slurm jobs are done
1 change: 1 addition & 0 deletions src/turnkeyml/run/benchmark_model.py
@@ -101,6 +101,7 @@ def run(
rt_args: Optional[str] = None,
):

# raise Exception("lmao!")
selected_runtime = apply_default_runtime(device, runtime)

# Get the default part and config by providing the Device class with
85 changes: 0 additions & 85 deletions src/turnkeyml/sequence/build_api.py

This file was deleted.
