
Further refactoring.
jeremyfowers committed Dec 2, 2023
1 parent 0f0ac36 commit ffa690c
Showing 3 changed files with 128 additions and 115 deletions.
207 changes: 96 additions & 111 deletions src/turnkeyml/analyze/script.py
@@ -84,6 +84,15 @@ def _store_traceback(invocation_info: util.UniqueInvocationInfo):
invocation_info.status_message = " ".join(invocation_info.status_message.split())


def set_status_on_exception(build_state: build.State, stats: fs.Stats):
# We get `state` when the build tool succeeds, so we can use that to identify
# whether the exception was thrown during build or benchmark
if not build_state:
stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.FAILED)
else:
stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)


def explore_invocation(
model_inputs: dict,
model_info: util.ModelInfo,
@@ -173,10 +182,6 @@ def explore_invocation(
invocation_info.stats = stats

# Stats that apply to the model, regardless of build
stats.save_stat(
fs.Keys.ANALYSIS_STATUS,
fs.FunctionStatus.SUCCESSFUL,
)
stats.save_stat(
fs.Keys.HASH,
model_info.hash,
@@ -196,6 +201,13 @@ def explore_invocation(
if fs.Keys.TASK in tracer_args.labels:
stats.save_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0])

# Save the system information used for this evaluation
system_info = build.get_system_info()
stats.save_stat(
fs.Keys.SYSTEM_INFO,
system_info,
)

# If the input script is a built-in TurnkeyML model, make a note of
# which one
if os.path.abspath(fs.MODELS_DIR) in os.path.abspath(tracer_args.input):
@@ -228,86 +240,103 @@ def explore_invocation(
tracer_args.iterations,
)

if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
invocation_info.status_message = (
"Skipping model compiled using torch.compile(). "
"turnkey requires models to be in eager mode "
"(regardless of what runtime you have selected)."
)
invocation_info.status_message_color = printing.Colors.WARNING

return

build_state = None
perf = None
try:
if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
invocation_info.status_message = (
"Skipping model compiled using torch.compile(). "
"turnkey requires models to be in eager mode "
"(regardless of what runtime you have selected)."
# Run the build tool (if needed by the runtime)
if runtime_info["build_required"]:
# Indicate that the build is running. If the build fails for any reason,
# we will try to catch the exception and note it in the stats.
# If a concluded build still has a status of "running", this means
# there was an uncaught exception.
stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.RUNNING)

build_state = build_model(
model=model_info.model,
inputs=inputs,
stats_id=stats_id,
build_name=build_name,
cache_dir=tracer_args.cache_dir,
rebuild=tracer_args.rebuild,
sequence=sequence_selected,
onnx_opset=tracer_args.onnx_opset,
device=tracer_args.device,
)
invocation_info.status_message_color = printing.Colors.WARNING
else:
if runtime_info["build_required"]:
# Indicate that the build is running. If the build fails for any reason,
# we will try to catch the exception and note it in the stats.
# If a concluded build still has a status of "running", this means
# there was an uncaught exception.
stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.RUNNING)

state = build_model(
model=model_info.model,
inputs=inputs,
stats_id=stats_id,
build_name=build_name,
cache_dir=tracer_args.cache_dir,
rebuild=tracer_args.rebuild,
sequence=sequence_selected,
onnx_opset=tracer_args.onnx_opset,
device=tracer_args.device,
)

stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.SUCCESSFUL)
stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.SUCCESSFUL)

model_to_benchmark = state.results[0]
else:
model_to_benchmark = model_info.model
model_to_benchmark = build_state.results[0]

if Action.BENCHMARK in tracer_args.actions:
stats.add_build_stat(
fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.RUNNING
)
# Analyze the onnx file (if any) and save statistics
util.analyze_onnx(
build_name=build_name,
cache_dir=tracer_args.cache_dir,
stats=stats,
)
else:
model_to_benchmark = model_info.model

if tracer_args.rt_args is None:
rt_args_to_use = {}
else:
rt_args_to_use = tracer_args.rt_args

model_handle = runtime_info["RuntimeClass"](
cache_dir=tracer_args.cache_dir,
build_name=build_name,
stats=stats,
iterations=tracer_args.iterations,
model=model_to_benchmark,
inputs=inputs,
device_type=tracer_args.device,
runtime=selected_runtime,
**rt_args_to_use,
)
perf = model_handle.benchmark()
# Run the benchmark tool (if requested by the user)
if Action.BENCHMARK in tracer_args.actions:
if tracer_args.rt_args is None:
rt_args_to_use = {}
else:
rt_args_to_use = tracer_args.rt_args

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.RUNNING)

model_handle = runtime_info["RuntimeClass"](
cache_dir=tracer_args.cache_dir,
build_name=build_name,
stats=stats,
iterations=tracer_args.iterations,
model=model_to_benchmark,
inputs=inputs,
device_type=tracer_args.device,
runtime=selected_runtime,
**rt_args_to_use,
)
perf = model_handle.benchmark()

for key, value in vars(perf).items():
stats.add_build_stat(
fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.SUCCESSFUL
key=key,
value=value,
)

invocation_info.status_message = "Model successfully benchmarked!"
invocation_info.performance = perf
invocation_info.status_message_color = printing.Colors.OKGREEN
else:
invocation_info.status_message = "Model successfully built!"
invocation_info.status_message_color = printing.Colors.OKGREEN
stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.SUCCESSFUL)

invocation_info.status_message = "Model successfully benchmarked!"
invocation_info.performance = perf
invocation_info.status_message_color = printing.Colors.OKGREEN
else:
invocation_info.status_message = "Model successfully built!"
invocation_info.status_message_color = printing.Colors.OKGREEN

except exp.StageError as e:
invocation_info.status_message = f"Build Error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.FAILED)
set_status_on_exception(build_state, stats)

_store_traceback(invocation_info)

except exp.SkipBuild:
# SkipBuild is an exception that the build_model() API will raise
# when it is skipping a previously-failed build when rebuild=never is set

# NOTE: skipping a build should never update build or benchmark status

invocation_info.status_message = (
"Build intentionally skipped because rebuild=never"
)
@@ -318,13 +347,15 @@ def explore_invocation(
# illegal. In that case we want to halt execution so that users can
# fix their arguments.

set_status_on_exception(build_state, stats)

raise e

except exp.Error as e:
invocation_info.status_message = f"Error: {e}."
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)
set_status_on_exception(build_state, stats)

_store_traceback(invocation_info)

@@ -334,61 +365,15 @@ def explore_invocation(
invocation_info.status_message = f"Unknown turnkey error: {e}"
invocation_info.status_message_color = printing.Colors.WARNING

stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)
set_status_on_exception(build_state, stats)

_store_traceback(invocation_info)

finally:
# Ensure that stdout/stderr is not being forwarded before updating status
util.stop_logger_forward()

system_info = build.get_system_info()
stats.save_stat(
fs.Keys.SYSTEM_INFO,
system_info,
)
if model_info.model_type != build.ModelType.PYTORCH_COMPILED:
# We have this if-block because torch-compiled model instances
# are not legal input to this function. So when we encounter one,
# we want to exit the function as quickly as possible, without
# doing any of the logic that follows this comment.

# ONNX stats that we want to save into the build's turnkey_stats.yaml file
# so that they can be easily accessed by the report command later
if fs.Keys.ONNX_FILE in stats.build_stats.keys():
# Just in case the ONNX file was generated on a different machine:
# strip the state's cache dir, then prepend the current cache dir
final_onnx_file = fs.rebase_cache_dir(
stats.build_stats[fs.Keys.ONNX_FILE],
build_name,
tracer_args.cache_dir,
)

onnx_ops_counter = util.get_onnx_ops_list(final_onnx_file)
onnx_model_info = util.populate_onnx_model_info(final_onnx_file)
onnx_input_dimensions = util.onnx_input_dimensions(final_onnx_file)

stats.save_stat(
fs.Keys.ONNX_OPS_COUNTER,
onnx_ops_counter,
)
stats.save_stat(
fs.Keys.ONNX_MODEL_INFO,
onnx_model_info,
)
stats.save_stat(
fs.Keys.ONNX_INPUT_DIMENSIONS,
onnx_input_dimensions,
)

if perf:
for key, value in vars(perf).items():
stats.add_build_stat(
key=key,
value=value,
)

status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)
status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)

if tracer_args.lean_cache:
printing.log_info("Removing build artifacts...")
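Note: the main change in script.py is that the separate analysis/build/benchmark branches are flattened into a single try block, with the new set_status_on_exception() helper deciding whether a caught exception should mark the build or the benchmark as failed, based on whether build_model() has already returned a state. Below is a minimal, self-contained sketch of that pattern; FakeStats and the plain string statuses are hypothetical stand-ins for turnkeyml's fs.Stats and fs.FunctionStatus, not the real API.

    # Sketch of the status-on-exception pattern, with simplified stand-ins.
    class FakeStats:
        def __init__(self):
            self.build_stats = {}

        def add_build_stat(self, key, value):
            self.build_stats[key] = value


    def set_status_on_exception(build_state, stats):
        # build_state is only non-None once the build tool has returned, so a
        # missing state means the exception came from the build phase and a
        # present state means it came from the benchmark phase.
        if not build_state:
            stats.add_build_stat("build_status", "failed")
        else:
            stats.add_build_stat("benchmark_status", "failed")


    stats = FakeStats()
    build_state = None
    try:
        stats.add_build_stat("build_status", "running")
        raise RuntimeError("simulated build failure")
    except RuntimeError:
        set_status_on_exception(build_state, stats)

    print(stats.build_stats)  # {'build_status': 'failed'}
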
34 changes: 32 additions & 2 deletions src/turnkeyml/analyze/util.py
@@ -8,7 +8,7 @@
from turnkeyml.common import printing
import turnkeyml.common.build as build
from turnkeyml.common.performance import MeasuredPerformance
from turnkeyml.common.filesystem import Stats
import turnkeyml.common.filesystem as fs


class AnalysisException(Exception):
@@ -37,7 +37,7 @@ class UniqueInvocationInfo:
status_message_color: printing.Colors = printing.Colors.ENDC
traceback_message_color: printing.Colors = printing.Colors.FAIL
stats_keys: Optional[List[str]] = None
stats: Stats = None
stats: fs.Stats = None


@dataclass
@@ -162,3 +162,33 @@ def stop_logger_forward() -> None:
sys.stdout = sys.stdout.terminal
if hasattr(sys.stderr, "terminal_err"):
sys.stderr = sys.stderr.terminal_err


def analyze_onnx(build_name: str, cache_dir: str, stats: fs.Stats):
# ONNX stats that we want to save into the build's turnkey_stats.yaml file
# so that they can be easily accessed by the report command later
if fs.Keys.ONNX_FILE in stats.build_stats.keys():
# Just in case the ONNX file was generated on a different machine:
# strip the state's cache dir, then prepend the current cache dir
final_onnx_file = fs.rebase_cache_dir(
stats.build_stats[fs.Keys.ONNX_FILE],
build_name,
cache_dir,
)

onnx_ops_counter = get_onnx_ops_list(final_onnx_file)
onnx_model_info = populate_onnx_model_info(final_onnx_file)
input_dimensions = onnx_input_dimensions(final_onnx_file)

stats.save_stat(
fs.Keys.ONNX_OPS_COUNTER,
onnx_ops_counter,
)
stats.save_stat(
fs.Keys.ONNX_MODEL_INFO,
onnx_model_info,
)
stats.save_stat(
fs.Keys.ONNX_INPUT_DIMENSIONS,
input_dimensions,
)
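
Note: this file moves the ONNX post-processing that used to live in explore_invocation()'s finally block into a reusable analyze_onnx() helper. The sketch below mirrors that data flow with simplified stand-ins: the path re-rooting, stat names, and example values are hypothetical, and the real helper derives its results from util.get_onnx_ops_list(), util.populate_onnx_model_info(), and util.onnx_input_dimensions().

    import os


    def rebase_cache_dir(onnx_path, build_name, cache_dir):
        # Simplified re-rooting: keep only the file name, then rebuild the path
        # under the current cache dir (the real fs.rebase_cache_dir strips the
        # original cache dir prefix instead).
        return os.path.join(cache_dir, build_name, os.path.basename(onnx_path))


    def analyze_onnx(build_name, cache_dir, build_stats, save_stat):
        if "onnx_file" in build_stats:
            final_onnx_file = rebase_cache_dir(
                build_stats["onnx_file"], build_name, cache_dir
            )
            # Placeholder values; the real helper computes these from the ONNX file
            save_stat("onnx_ops_counter", {"MatMul": 4, "Relu": 3})
            save_stat("onnx_model_info", {"opset": 14, "path": final_onnx_file})
            save_stat("onnx_input_dimensions", {"input": [1, 3, 224, 224]})


    collected = {}
    analyze_onnx(
        build_name="example_build",
        cache_dir="/tmp/turnkey_cache",
        build_stats={"onnx_file": "/old/cache/example_build/onnx/model.onnx"},
        save_stat=lambda key, value: collected.update({key: value}),
    )
    print(sorted(collected))  # ['onnx_input_dimensions', 'onnx_model_info', 'onnx_ops_counter']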
2 changes: 0 additions & 2 deletions src/turnkeyml/common/filesystem.py
@@ -352,8 +352,6 @@ class Keys:
SYSTEM_INFO = "system_info"
# Path to the built-in model script used as input
MODEL_SCRIPT = "builtin_model_script"
# Indicates status of the most recent analysis tool run: FunctionStatus
ANALYSIS_STATUS = "analysis_status"
# Indicates status of the most recent build tool run: FunctionStatus
BUILD_STATUS = "build_status"
# Indicates status of the most recent benchmark tool run: FunctionStatus
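Note: filesystem.py simply drops the analysis_status key, so the per-build status surface is now only the build and benchmark entries. A sketch of how a report consumer might read those two keys follows; the FunctionStatus string values and the benchmark key's value are assumptions inferred from the other hunks, not taken from the real file.

    class Keys:
        # Remaining status keys after this commit
        BUILD_STATUS = "build_status"
        BENCHMARK_STATUS = "benchmark_status"  # assumed value, cut off in the diff


    class FunctionStatus:
        # Assumed string values; only the member names appear in the diff
        RUNNING = "running"
        SUCCESSFUL = "successful"
        FAILED = "failed"


    def summarize(build_stats):
        # A report no longer needs to check analysis_status at all
        return {
            "built": build_stats.get(Keys.BUILD_STATUS) == FunctionStatus.SUCCESSFUL,
            "benchmarked": build_stats.get(Keys.BENCHMARK_STATUS) == FunctionStatus.SUCCESSFUL,
        }


    print(summarize({Keys.BUILD_STATUS: FunctionStatus.SUCCESSFUL,
                     Keys.BENCHMARK_STATUS: FunctionStatus.FAILED}))
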
