From ffa690ca3d10c217d975c0c521f6a6769a0d3afb Mon Sep 17 00:00:00 2001
From: Jeremy Fowers
Date: Sat, 2 Dec 2023 15:30:51 -0500
Subject: [PATCH] Further refactoring.

---
 src/turnkeyml/analyze/script.py    | 207 +++++++++++++----------
 src/turnkeyml/analyze/util.py      |  34 ++++-
 src/turnkeyml/common/filesystem.py |   2 -
 3 files changed, 128 insertions(+), 115 deletions(-)

diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py
index 24a49af..87d2ff5 100644
--- a/src/turnkeyml/analyze/script.py
+++ b/src/turnkeyml/analyze/script.py
@@ -84,6 +84,15 @@ def _store_traceback(invocation_info: util.UniqueInvocationInfo):
     invocation_info.status_message = " ".join(invocation_info.status_message.split())
 
 
+def set_status_on_exception(build_state: build.State, stats: fs.Stats):
+    # We only get `build_state` when the build tool succeeds, so we can use it to
+    # identify whether the exception was thrown during the build or the benchmark
+    if not build_state:
+        stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.FAILED)
+    else:
+        stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)
+
+
 def explore_invocation(
     model_inputs: dict,
     model_info: util.ModelInfo,
@@ -173,10 +182,6 @@ def explore_invocation(
     invocation_info.stats = stats
 
     # Stats that apply to the model, regardless of build
-    stats.save_stat(
-        fs.Keys.ANALYSIS_STATUS,
-        fs.FunctionStatus.SUCCESSFUL,
-    )
     stats.save_stat(
         fs.Keys.HASH,
         model_info.hash,
@@ -196,6 +201,13 @@ def explore_invocation(
     if fs.Keys.TASK in tracer_args.labels:
         stats.save_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0])
 
+    # Save the system information used for this evaluation
+    system_info = build.get_system_info()
+    stats.save_stat(
+        fs.Keys.SYSTEM_INFO,
+        system_info,
+    )
+
     # If the input script is a built-in TurnkeyML model, make a note of
     # which one
     if os.path.abspath(fs.MODELS_DIR) in os.path.abspath(tracer_args.input):
@@ -228,86 +240,103 @@ def explore_invocation(
             tracer_args.iterations,
         )
 
+    if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
+        invocation_info.status_message = (
+            "Skipping model compiled using torch.compile(). "
+            "turnkey requires models to be in eager mode "
+            "(regardless of what runtime you have selected)."
+        )
+        invocation_info.status_message_color = printing.Colors.WARNING
+
+        return
+
+    build_state = None
     perf = None
     try:
-        if model_info.model_type == build.ModelType.PYTORCH_COMPILED:
-            invocation_info.status_message = (
-                "Skipping model compiled using torch.compile(). "
-                "turnkey requires models to be in eager mode "
-                "(regardless of what runtime you have selected)."
+        # Run the build tool (if needed by the runtime)
+        if runtime_info["build_required"]:
+            # Indicate that the build is running. If the build fails for any reason,
+            # we will try to catch the exception and note it in the stats.
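+            # (The except blocks below call set_status_on_exception() to record FAILED.)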
- # If a concluded build still has a status of "running", this means - # there was an uncaught exception. - stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.RUNNING) - - state = build_model( - model=model_info.model, - inputs=inputs, - stats_id=stats_id, - build_name=build_name, - cache_dir=tracer_args.cache_dir, - rebuild=tracer_args.rebuild, - sequence=sequence_selected, - onnx_opset=tracer_args.onnx_opset, - device=tracer_args.device, - ) - stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.SUCCESSFUL) + stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.SUCCESSFUL) - model_to_benchmark = state.results[0] - else: - model_to_benchmark = model_info.model + model_to_benchmark = build_state.results[0] - if Action.BENCHMARK in tracer_args.actions: - stats.add_build_stat( - fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.RUNNING - ) + # Analyze the onnx file (if any) and save statistics + util.analyze_onnx( + build_name=build_name, + cache_dir=tracer_args.cache_dir, + stats=stats, + ) + else: + model_to_benchmark = model_info.model - if tracer_args.rt_args is None: - rt_args_to_use = {} - else: - rt_args_to_use = tracer_args.rt_args - - model_handle = runtime_info["RuntimeClass"]( - cache_dir=tracer_args.cache_dir, - build_name=build_name, - stats=stats, - iterations=tracer_args.iterations, - model=model_to_benchmark, - inputs=inputs, - device_type=tracer_args.device, - runtime=selected_runtime, - **rt_args_to_use, - ) - perf = model_handle.benchmark() + # Run the benchmark tool (if requested by the user) + if Action.BENCHMARK in tracer_args.actions: + if tracer_args.rt_args is None: + rt_args_to_use = {} + else: + rt_args_to_use = tracer_args.rt_args + + stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.RUNNING) + + model_handle = runtime_info["RuntimeClass"]( + cache_dir=tracer_args.cache_dir, + build_name=build_name, + stats=stats, + iterations=tracer_args.iterations, + model=model_to_benchmark, + inputs=inputs, + device_type=tracer_args.device, + runtime=selected_runtime, + **rt_args_to_use, + ) + perf = model_handle.benchmark() + for key, value in vars(perf).items(): stats.add_build_stat( - fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.SUCCESSFUL + key=key, + value=value, ) - invocation_info.status_message = "Model successfully benchmarked!" - invocation_info.performance = perf - invocation_info.status_message_color = printing.Colors.OKGREEN - else: - invocation_info.status_message = "Model successfully built!" - invocation_info.status_message_color = printing.Colors.OKGREEN + stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.SUCCESSFUL) + + invocation_info.status_message = "Model successfully benchmarked!" + invocation_info.performance = perf + invocation_info.status_message_color = printing.Colors.OKGREEN + else: + invocation_info.status_message = "Model successfully built!" 
+            invocation_info.status_message_color = printing.Colors.OKGREEN
 
     except exp.StageError as e:
         invocation_info.status_message = f"Build Error: {e}"
         invocation_info.status_message_color = printing.Colors.WARNING
 
-        stats.add_build_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.FAILED)
+        set_status_on_exception(build_state, stats)
 
         _store_traceback(invocation_info)
 
     except exp.SkipBuild:
         # SkipBuild is an exception that the build_model() API will raise
         # when it is skipping a previously-failed build when rebuild=never is set
+
+        # NOTE: skipping a build should never update build or benchmark status
+
         invocation_info.status_message = (
             "Build intentionally skipped because rebuild=never"
         )
@@ -318,13 +347,15 @@ def explore_invocation(
         # illegal. In that case we want to halt execution so that users can
         # fix their arguments.
 
+        set_status_on_exception(build_state, stats)
+
         raise e
 
     except exp.Error as e:
         invocation_info.status_message = f"Error: {e}."
         invocation_info.status_message_color = printing.Colors.WARNING
 
-        stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)
+        set_status_on_exception(build_state, stats)
 
         _store_traceback(invocation_info)
 
@@ -334,7 +365,7 @@ def explore_invocation(
         invocation_info.status_message = f"Unknown turnkey error: {e}"
         invocation_info.status_message_color = printing.Colors.WARNING
 
-        stats.add_build_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED)
+        set_status_on_exception(build_state, stats)
 
         _store_traceback(invocation_info)
 
@@ -342,53 +373,7 @@ def explore_invocation(
         # Ensure that stdout/stderr is not being forwarded before updating status
         util.stop_logger_forward()
 
-    system_info = build.get_system_info()
-    stats.save_stat(
-        fs.Keys.SYSTEM_INFO,
-        system_info,
-    )
-    if model_info.model_type != build.ModelType.PYTORCH_COMPILED:
-        # We have this if-block because torch-compiled model instances
-        # are not legal input to this function. So when we encounter one,
-        # we want to exit the function as quickly as possible, without
-        # doing any of the logic that follows this comment.
-
-        # ONNX stats that we want to save into the build's turnkey_stats.yaml file
-        # so that they can be easily accessed by the report command later
-        if fs.Keys.ONNX_FILE in stats.build_stats.keys():
-            # Just in case the ONNX file was generated on a different machine:
-            # strip the state's cache dir, then prepend the current cache dir
-            final_onnx_file = fs.rebase_cache_dir(
-                stats.build_stats[fs.Keys.ONNX_FILE],
-                build_name,
-                tracer_args.cache_dir,
-            )
-
-            onnx_ops_counter = util.get_onnx_ops_list(final_onnx_file)
-            onnx_model_info = util.populate_onnx_model_info(final_onnx_file)
-            onnx_input_dimensions = util.onnx_input_dimensions(final_onnx_file)
-
-            stats.save_stat(
-                fs.Keys.ONNX_OPS_COUNTER,
-                onnx_ops_counter,
-            )
-            stats.save_stat(
-                fs.Keys.ONNX_MODEL_INFO,
-                onnx_model_info,
-            )
-            stats.save_stat(
-                fs.Keys.ONNX_INPUT_DIMENSIONS,
-                onnx_input_dimensions,
-            )
-
-        if perf:
-            for key, value in vars(perf).items():
-                stats.add_build_stat(
-                    key=key,
-                    value=value,
-                )
-
-        status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)
+    status.update(tracer_args.models_found, build_name, tracer_args.cache_dir)
 
     if tracer_args.lean_cache:
         printing.log_info("Removing build artifacts...")
diff --git a/src/turnkeyml/analyze/util.py b/src/turnkeyml/analyze/util.py
index 27e9659..1d3a284 100644
--- a/src/turnkeyml/analyze/util.py
+++ b/src/turnkeyml/analyze/util.py
@@ -8,7 +8,7 @@
 from turnkeyml.common import printing
 import turnkeyml.common.build as build
 from turnkeyml.common.performance import MeasuredPerformance
-from turnkeyml.common.filesystem import Stats
+import turnkeyml.common.filesystem as fs
 
 
 class AnalysisException(Exception):
@@ -37,7 +37,7 @@ class UniqueInvocationInfo:
     status_message_color: printing.Colors = printing.Colors.ENDC
     traceback_message_color: printing.Colors = printing.Colors.FAIL
     stats_keys: Optional[List[str]] = None
-    stats: Stats = None
+    stats: fs.Stats = None
 
 
 @dataclass
@@ -162,3 +162,33 @@ def stop_logger_forward() -> None:
         sys.stdout = sys.stdout.terminal
     if hasattr(sys.stderr, "terminal_err"):
         sys.stderr = sys.stderr.terminal_err
+
+
+def analyze_onnx(build_name: str, cache_dir: str, stats: fs.Stats):
+    # ONNX stats that we want to save into the build's turnkey_stats.yaml file
+    # so that they can be easily accessed by the report command later
+    if fs.Keys.ONNX_FILE in stats.build_stats.keys():
+        # Just in case the ONNX file was generated on a different machine:
+        # strip the state's cache dir, then prepend the current cache dir
+        final_onnx_file = fs.rebase_cache_dir(
+            stats.build_stats[fs.Keys.ONNX_FILE],
+            build_name,
+            cache_dir,
+        )
+
+        onnx_ops_counter = get_onnx_ops_list(final_onnx_file)
+        onnx_model_info = populate_onnx_model_info(final_onnx_file)
+        input_dimensions = onnx_input_dimensions(final_onnx_file)
+
+        stats.save_stat(
+            fs.Keys.ONNX_OPS_COUNTER,
+            onnx_ops_counter,
+        )
+        stats.save_stat(
+            fs.Keys.ONNX_MODEL_INFO,
+            onnx_model_info,
+        )
+        stats.save_stat(
+            fs.Keys.ONNX_INPUT_DIMENSIONS,
+            input_dimensions,
+        )
diff --git a/src/turnkeyml/common/filesystem.py b/src/turnkeyml/common/filesystem.py
index 5396a6d..1c8a875 100644
--- a/src/turnkeyml/common/filesystem.py
+++ b/src/turnkeyml/common/filesystem.py
@@ -352,8 +352,6 @@ class Keys:
     SYSTEM_INFO = "system_info"
     # Path to the built-in model script used as input
     MODEL_SCRIPT = "builtin_model_script"
-    # Indicates status of the most recent analysis tool run: FunctionStatus
-    ANALYSIS_STATUS = "analysis_status"
     # Indicates status of the most recent build tool run: FunctionStatus
     BUILD_STATUS = "build_status"
     # Indicates status of the most recent benchmark tool run: FunctionStatus