Stop saving labels files in the cache (#54)
Signed-off-by: Jeremy Fowers <[email protected]>
jeremyfowers committed Dec 6, 2023
1 parent 7438748 commit f25f1af
Showing 6 changed files with 17 additions and 46 deletions.
2 changes: 1 addition & 1 deletion models/readme.md
@@ -114,7 +114,7 @@ Example:
# labels: author::google test_group::daily,monthly
```
Labels are saved in your cache directory and can later be retrieved using the function `turnkey.common.labels.load_from_cache()`, which receives the `cache_dir` and `build_name` as inputs and returns the labels as a dictionary.
Labels are saved in your cache directory in the `turnkey_stats.yaml` file under the "labels" key.

### Parameters

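A note on the readme change above: after this commit, labels live in the build's `turnkey_stats.yaml` rather than in a separate labels file. Below is a minimal sketch of retrieving them, assuming the stats file sits at `<cache_dir>/<build_name>/turnkey_stats.yaml` and is plain YAML; the helper is illustrative, not part of the TurnkeyML API.

```python
import os
import yaml  # PyYAML

def read_labels(cache_dir: str, build_name: str) -> dict:
    """Illustrative helper: pull the labels dict out of a build's stats file.

    The <cache_dir>/<build_name>/turnkey_stats.yaml location is an assumption
    about the cache layout; only the file name and the "labels" key come from
    the readme text above.
    """
    stats_path = os.path.join(cache_dir, build_name, "turnkey_stats.yaml")
    with open(stats_path, encoding="utf-8") as f:
        stats = yaml.safe_load(f)
    return stats.get("labels", {})

# A header like "# labels: author::google test_group::daily,monthly" would
# come back as {"author": ["google"], "test_group": ["daily", "monthly"]}.
```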
12 changes: 5 additions & 7 deletions src/turnkeyml/analyze/script.py
@@ -138,11 +138,6 @@ def explore_invocation(
inputs[all_args[i]] = args[i]
invocation_info.inputs = inputs

# Save model labels
if model_info.model_type != build.ModelType.ONNX_FILE:
tracer_args.labels["class"] = [f"{type(model_info.model).__name__}"]
labels.save_to_cache(tracer_args.cache_dir, build_name, tracer_args.labels)

# If the user has not provided a specific runtime, select the runtime
# based on the device provided.
if tracer_args.runtime is None:
@@ -182,13 +177,16 @@ def explore_invocation(
fs.Keys.PARAMETERS,
model_info.params,
)
if model_info.model_type != build.ModelType.ONNX_FILE:
stats.save_stat(fs.Keys.CLASS, type(model_info.model).__name__)
if fs.Keys.AUTHOR in tracer_args.labels:
stats.save_stat(fs.Keys.AUTHOR, tracer_args.labels[fs.Keys.AUTHOR][0])
if fs.Keys.CLASS in tracer_args.labels:
stats.save_stat(fs.Keys.CLASS, tracer_args.labels[fs.Keys.CLASS][0])
if fs.Keys.TASK in tracer_args.labels:
stats.save_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0])

# Save all of the labels in one place
stats.save_stat(fs.Keys.LABELS, tracer_args.labels)

# If the input script is a built-in TurnkeyML model, make a note of
# which one
if os.path.abspath(fs.MODELS_DIR) in os.path.abspath(tracer_args.input):
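For context on the script.py hunk above: the new code records each recognized label as its own stat and then stores the whole label dictionary under the labels key. Here is a rough sketch of that flow with a stand-in for `stats.save_stat` (the real `Stats` object and `fs.Keys` constants live in `turnkeyml.common.filesystem`); it is a sketch of the behavior, not the actual implementation.

```python
from typing import Callable, Dict, List

def record_label_stats(
    labels: Dict[str, List[str]],
    save_stat: Callable[[str, object], None],
) -> None:
    """Sketch of the post-commit behavior: save individual stats for known
    label keys, then keep every label in one place under "labels"."""
    for key in ("author", "class", "task"):
        if key in labels:
            save_stat(key, labels[key][0])
    save_stat("labels", labels)

# Usage with a plain dict standing in for the stats store:
stats = {}
record_label_stats(
    {"author": ["google"], "class": ["BertModel"], "tags": ["selftest"]},
    stats.__setitem__,
)
# stats["author"] == "google", stats["class"] == "BertModel", and
# stats["labels"] holds the full dictionary, "tags" included.
```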
2 changes: 2 additions & 0 deletions src/turnkeyml/common/filesystem.py
@@ -333,6 +333,8 @@ class Keys:
MODEL_NAME = "model_name"
# References the per-build stats section
BUILDS = "builds"
# Catch-all for storing a file's labels
LABELS = "labels"
# Author of the model
AUTHOR = "author"
# Class type of the model
31 changes: 0 additions & 31 deletions src/turnkeyml/common/labels.py
@@ -1,4 +1,3 @@
import os
from typing import Dict, List
import turnkeyml.common.printing as printing

@@ -44,36 +43,6 @@ def load_from_file(file_path: str) -> Dict[str, List[str]]:
return {}


def load_from_cache(cache_dir: str, build_name: str) -> Dict[str, List[str]]:
"""
Loads labels from the cache directory
"""
# Open file
file_path = os.path.join(cache_dir, "labels", f"{build_name}.txt")
with open(file_path, encoding="utf-8") as f:
first_line = f.readline()

# Return label dict
label_list = first_line.replace("\n", "").split(" ")
return to_dict(label_list)


def save_to_cache(cache_dir: str, build_name: str, label_dict: Dict[str, List[str]]):
"""
Save labels as a stand-alone file as part of the cache directory
"""
labels_list = [f"{k}::{','.join(label_dict[k])}" for k in label_dict.keys()]

# Create labels folder if it doesn't exist
labels_dir = os.path.join(cache_dir, "labels")
os.makedirs(labels_dir, exist_ok=True)

# Save labels to cache
file_path = os.path.join(labels_dir, f"{build_name}.txt")
with open(file_path, "w", encoding="utf8") as fp:
fp.write(" ".join(labels_list))


def is_subset(label_dict_a: Dict[str, List[str]], label_dict_b: Dict[str, List[str]]):
"""
This function returns True if label_dict_a is a subset of label_dict_b.
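The deleted `load_from_cache`/`save_to_cache` helpers above serialized labels as space-separated `key::value,value` tokens; the same format still appears in script headers and is parsed by the remaining `load_from_file`/`to_dict` helpers. The small sketch below shows the round trip, written to mirror the join logic in the deleted `save_to_cache`; it is illustrative, not the module's actual implementation.

```python
from typing import Dict, List

def to_dict(label_list: List[str]) -> Dict[str, List[str]]:
    """Parse "key::v1,v2" tokens into {key: [v1, v2]} (illustrative)."""
    return {
        token.split("::")[0]: token.split("::")[1].split(",")
        for token in label_list
    }

def to_label_string(label_dict: Dict[str, List[str]]) -> str:
    """Inverse direction, mirroring the join in the deleted save_to_cache."""
    return " ".join(f"{k}::{','.join(v)}" for k, v in label_dict.items())

labels = to_dict("author::google test_group::daily,monthly".split(" "))
assert labels == {"author": ["google"], "test_group": ["daily", "monthly"]}
assert to_label_string(labels) == "author::google test_group::daily,monthly"
```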
9 changes: 5 additions & 4 deletions test/analysis.py
@@ -36,8 +36,7 @@
# filesystem access

test_scripts_dot_py = {
"linear_pytorch.py": """
# labels: test_group::selftest license::mit framework::pytorch tags::selftest,small
"linear_pytorch.py": """# labels: test_group::selftest license::mit framework::pytorch tags::selftest,small
import torch
import argparse
@@ -235,8 +234,10 @@ def test_05_cache(self):
]
)
build_name = f"linear_pytorch_{model_hash}"
labels_found = labels.load_from_cache(cache_dir, build_name) != {}
assert cache_is_lean(cache_dir, build_name) and labels_found
labels_found = filesystem.Stats(cache_dir, build_name).stats[
filesystem.Keys.LABELS
]
assert cache_is_lean(cache_dir, build_name) and labels_found != {}, labels_found

def test_06_generic_args(self):
output = run_cli(
7 changes: 4 additions & 3 deletions test/cli.py
@@ -311,9 +311,10 @@ def test_021_cli_report(self):
]
linear_summary = summary[1]
assert len(summary) == len(test_scripts)
assert all(
elem in linear_summary for elem in expected_cols
), f"Looked for each of {expected_cols} in {linear_summary.keys()}"
for elem in expected_cols:
assert (
elem in linear_summary
), f"Couldn't find expected key {elem} in results spreadsheet"

# Check whether all rows we expect to be populated are actually populated
assert (
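The test/cli.py hunk above swaps a single `all(...)` assertion for a per-column loop; the practical gain is that a failure now names the missing key. A toy illustration with made-up report columns:

```python
expected_cols = ["model_name", "author", "class"]  # hypothetical report columns
row = {"model_name": "linear", "author": "x", "class": "Linear"}  # hypothetical row

# With assert all(...), a failure only says that *some* column was missing.
# Asserting per column points at the exact key instead.
for col in expected_cols:
    assert col in row, f"Couldn't find expected key {col} in results spreadsheet"
```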
