From d86abcf04a77c8aa2c3d5847e27f02e23db2c740 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 5 Jan 2022 12:44:57 +0000 Subject: [PATCH 1/3] add scheduling evaluation in paddle dygraph mode --- api/common/launch.py | 141 ++++++++++- api/common/main.py | 30 +++ ...paddle_dynamic_scheduling_api_benchmark.py | 224 ++++++++++++++++++ api/dynamic_scheduling_tests_v2/__init__.py | 13 + api/dynamic_scheduling_tests_v2/abs.py | 39 +++ api/dynamic_scheduling_tests_v2/run.sh | 49 ++++ 6 files changed, 495 insertions(+), 1 deletion(-) create mode 100644 api/common/paddle_dynamic_scheduling_api_benchmark.py create mode 100644 api/dynamic_scheduling_tests_v2/__init__.py create mode 100644 api/dynamic_scheduling_tests_v2/abs.py create mode 100644 api/dynamic_scheduling_tests_v2/run.sh diff --git a/api/common/launch.py b/api/common/launch.py index 623de5caf6..019638d85b 100644 --- a/api/common/launch.py +++ b/api/common/launch.py @@ -168,7 +168,133 @@ def _parse_gpu_time(self, line): return gpu_time / percent -def launch(benchmark_script, benchmark_script_args, with_nvprof=False): +class NsightRunnerForDynamicScheduling(object): + def run(self, cmd, op_type, nvprof_start_step, nvprof_end_step, backward): + stdout, exit_code = self._nsight_for_dynamic_scheduling(cmd) + if exit_code == 0: + parse_status, gpu_time = self._parse_logs( + stdout.split("\n"), op_type, nvprof_start_step, + nvprof_end_step, backward) + if parse_status: + return gpu_time + print("Running Error:\n {}".format(stdout)) + return 0.0 + + def _nsight_for_dynamic_scheduling(self, cmd): + return system.run_command( + "nsys profile -t cuda,nvtx --stats true -o tmp.qdrep --force-overwrite true {}". + format(cmd)) + + def _to_float(s): + return float(s.replace(',', '')) + + def _calculate_avg_time(l): + total_time = _to_float(l[1]) + max_time = _to_float(l[5]) + calls = _to_float(l[2]) - 1 + return (total_time - max_time) / calls + + def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, + backward): + flag_nvtx_time = False + total_step_time = 0.0 + step_count = 0 + + # 0: imperative (imperative_avg_time) + # 1: op_type (fwd_trace_op_avg_time) + # 2: op_type compute (fwd_op_compute_avg_time) + # 3: op_type_grad (bwd_trace_op_avg_time) + # 4: op_type_grad compute (bwd_op_compute_avg_time) + _scheduling_list = [ + 'imperative', op_type, op_type + ' compute', op_type + '_grad', + op_type + '_grad compute' + ] + nvtx_meta_data_dict = {} + scheduling_time_dict = {} + + for i in range(len(logs)): + line = api_param.parse_string(logs[i]) + if flag_nvtx_time: + infos = line.strip().split() + if not infos: + continue + nvtx_range_type = infos[-1] + if nvtx_range_type == 'compute' or nvtx_range_type == 'infer_shape': + nvtx_range_type = infos[-2] + ' ' + nvtx_range_type + + # step time + if nvtx_range_type.isdigit() and int( + nvtx_range_type) > nvprof_start_step and int( + nvtx_range_type) < nvprof_end_step: + step_count += 1 + step_time = _to_float(infos[1]) + total_step_time += step_time + + if nvtx_range_type in _scheduling_list: + avg_time = _calculate_avg_time(infos) + nvtx_meta_data_dict[nvtx_range_type] = avg_time + # print(nvtx_range_type + ' time: ', avg_time) + + if 'NVTX Push-Pop Range Statistics:' in line: + flag_nvtx_time = True + if step_count != 0: + nvtx_meta_data_dict['step'] = total_step_time / step_count + # print("num_step: ", step_count, " step_avg_time: ", total_step_time / step_count) + + scheduling_time_dict['step_avg_time'] = nvtx_meta_data_dict['step'] + scheduling_time_dict['imperative_avg_time'] = 
nvtx_meta_data_dict[ + _scheduling_list[0]] + scheduling_time_dict['fwd_trace_op_avg_time'] = nvtx_meta_data_dict[ + _scheduling_list[1]] + scheduling_time_dict['fwd_op_compute_avg_time'] = nvtx_meta_data_dict[ + _scheduling_list[2]] + scheduling_time_dict['bwd_trace_op_avg_time'] = nvtx_meta_data_dict[ + _scheduling_list[3]] + scheduling_time_dict['bwd_op_compute_avg_time'] = nvtx_meta_data_dict[ + _scheduling_list[4]] + if scheduling_time_dict['step_avg_time'] and scheduling_time_dict[ + 'imperative_avg_time']: + if not backward: + scheduling_time_dict[ + 'python_call_time'] = scheduling_time_dict[ + 'step_avg_time'] - scheduling_time_dict[ + 'imperative_avg_time'] + elif scheduling_time_dict['bwd_trace_op_avg_time']: + scheduling_time_dict[ + 'python_call_time'] = scheduling_time_dict[ + 'step_avg_time'] - scheduling_time_dict[ + 'imperative_avg_time'] - scheduling_time_dict[ + 'bwd_trace_op_avg_time'] + if scheduling_time_dict[ + 'imperative_avg_time'] and scheduling_time_dict[ + 'fwd_trace_op_avg_time']: + scheduling_time_dict[ + 'imperative_call_time'] = scheduling_time_dict[ + 'imperative_avg_time'] - scheduling_time_dict[ + 'fwd_trace_op_avg_time'] + if scheduling_time_dict[ + 'fwd_trace_op_avg_time'] and scheduling_time_dict[ + 'fwd_op_compute_avg_time']: + scheduling_time_dict[ + 'fwd_trace_op_call_time'] = scheduling_time_dict[ + 'fwd_trace_op_avg_time'] - scheduling_time_dict[ + 'fwd_op_compute_avg_time'] + if scheduling_time_dict[ + 'bwd_trace_op_avg_time'] and scheduling_time_dict[ + 'bwd_op_compute_avg_time']: + scheduling_time_dict[ + 'bwd_trace_op_call_time'] = scheduling_time_dict[ + 'bwd_trace_op_avg_time'] - scheduling_time_dict[ + 'bwd_op_compute_avg_time'] + + print(scheduling_time_dict) + return scheduling_time_dict + + +def launch(benchmark_script, + benchmark_script_args, + with_nvprof=False, + with_dynamic_scheduling=False): """ If with_nvprof is True, it will launch the following command firstly to get the gpu_time: @@ -188,10 +314,23 @@ def _set_profiler(args, value): args.append("--profiler") args.append(value) + def _split_arg_str_value(cmd, arg_name): + if arg_name not in cmd: + return None + return cmd.split("--" + arg_name)[1].strip().split()[0] + if with_nvprof: _set_profiler(benchmark_script_args, "nvprof") cmd = "{} {} {}".format(sys.executable, benchmark_script, " ".join(benchmark_script_args)) + if with_dynamic_scheduling: + runner = NsightRunnerForDynamicScheduling() + nvprof_start_step = int(_split_arg_str_value(cmd, "nvprof_start_step")) + nvprof_end_step = int(_split_arg_str_value(cmd, "nvprof_end_step")) + op_type = _split_arg_str_value(cmd, "api_name") + backward = bool(_split_arg_str_value(cmd, "backward")) + scheduling_time_dict = runner.run(cmd, op_type, nvprof_start_step, + nvprof_end_step, backward) if with_nvprof: if is_ampere_gpu(): runner = NsightRunner() diff --git a/api/common/main.py b/api/common/main.py index 3fab250920..a988b1ba75 100644 --- a/api/common/main.py +++ b/api/common/main.py @@ -99,6 +99,19 @@ def parse_args(): help='Total GPU kernel time parsed from nvprof') parser.add_argument( '--repeat', type=int, default=1, help='Iterations of Repeat running') + parser.add_argument( + '--is_dynamic_scheduling', + type=system.str2bool, + default=False, + help='Whether to calculate scheduling cost in dynamic mode [True|False]' + ) + parser.add_argument( + '--nvprof_start_step', + type=int, + default=1, + help='Start step of profile') + parser.add_argument( + '--nvprof_end_step', type=int, default=100, help='End step of profile') 
parser.add_argument( '--allow_adaptive_repeat', type=system.str2bool, @@ -302,6 +315,23 @@ def test_main_without_json(pd_obj=None, if pd_dy_outputs == False: sys.exit(1) + if args.is_dynamic_scheduling and _is_paddle_enabled( + args, config) and args.testing_mode == "dynamic": + assert pd_dy_obj is not None, "Paddle dynamic object is None." + print(config) + pd_dy_outputs, pd_dy_stats = pd_dy_obj.run(config, args, + feeder_adapter) + + if args.task == "speed": + pd_dy_stats["gpu_time"] = args.gpu_time + utils.print_benchmark_result( + pd_dy_stats, + log_level=args.log_level, + config_params=config.to_string()) + + if pd_dy_outputs == False: + sys.exit(1) + if args.task == "accuracy": is_run_tf = config.run_tf and args.testing_mode == "static" is_run_torch = config.run_torch and args.testing_mode == "dynamic" diff --git a/api/common/paddle_dynamic_scheduling_api_benchmark.py b/api/common/paddle_dynamic_scheduling_api_benchmark.py new file mode 100644 index 0000000000..283a7d79c4 --- /dev/null +++ b/api/common/paddle_dynamic_scheduling_api_benchmark.py @@ -0,0 +1,224 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import sys +import json +import time +import abc, six +import importlib +import numpy as np + +from common import utils +from common import api_param +from common import feeder +from common import special_op_list +from common.paddle_api_benchmark import profile_context + +try: + import paddle + from paddle.fluid import core +except Exception as e: + sys.stderr.write( + "Cannot import paddle.fluid, maybe paddle is not installed.\n") + +BEFORE_RUN = 0 +IN_RUN = 1 +AFTER_RUN = 2 + + +@six.add_metaclass(abc.ABCMeta) +class PaddleDynamicSchedulingBenchmarkBase(object): + def __init__(self): + self.name = self.__class__.__name__ + self._reset() + + @abc.abstractmethod + def build_graph(self, config=None): + pass + + def compute_flop_and_byte(self, config): + """ flop is used as a metric for op's performance and it is optional. 
+ """ + return None, None + + def variable(self, name, shape, dtype, value=None, stop_gradient=False): + if self._status == BEFORE_RUN: + if self._feed_values is not None and value is None: + i = len(self._feed_dict) + feed_value = feeder.check_shape_and_dtype( + shape=shape, dtype=dtype, value=self._feed_values[i]) + else: + assert shape is not None + + if self._feed_spec is not None and value is None: + i = len(self._feed_dict) + range = self._feed_spec[i].get("range", None) + else: + range = None + feed_value = feeder.generate_random_data( + shape, dtype, range=range, value=value) + var = paddle.to_tensor(feed_value, stop_gradient=stop_gradient) + self._feed_dict[name] = var + else: + var = self._feed_dict[name] + return var + + @property + def backward(self): + return self._backward + + def layers(self, api_name, module_name=None, **kwargs): + def _import_func(paddle_module_name, api_name): + try: + module = importlib.import_module(paddle_module_name) + func = getattr(module, api_name) + print("Successly import %s.%s" % + (paddle_module_name, api_name)) + return func + except Exception: + print("Failed to import %s.%s" % + (paddle_module_name, api_name)) + return None + + if self._layers_function is None: + paddle_module_names = ["paddle", "paddle.nn.functional"] + if module_name is not None and module_name not in paddle_module_names: + paddle_module_names.append(module_name) + + for paddle_module_name in paddle_module_names: + func = _import_func(paddle_module_name, api_name) + if func is not None: + break + + assert func is not None, "Need to specify module_name to import %s." % api_name + self._layers_function = func + + result = self._layers_function(**kwargs) + return result + + def append_gradients(self, targets, inputs): + if not isinstance(targets, list): + if len(self._ones_like_targets) == 0: + ones_like_targets = paddle.ones_like(targets) + self._ones_like_targets.append(ones_like_targets) + else: + ones_like_targets = self._ones_like_targets[0] + else: + ones_like_targets = None + gradients = paddle.grad( + outputs=targets, inputs=inputs, grad_outputs=ones_like_targets) + self._backward = True + if isinstance(gradients, list): + for grad in gradients: + self.fetch_list.append(grad) + else: + self.fetch_list.append(gradients) + + def run_impl(self, + use_gpu, + config, + repeat=1, + nvprof_start_step=10, + nvprof_end_step=100, + profiler="none", + feeder_adapter=None): + def _run_main_iter(): + self.build_graph(config=config) + if use_gpu: + paddle.fluid._cuda_synchronize(paddle.fluid.CUDAPlace(0)) + + outputs = None + if self._need_fetch: + outputs = [] + for var in self.fetch_list: + if isinstance(var, np.ndarray): + outputs.append(var) + else: + outputs.append(var.numpy()) + return outputs + + # warmup run + _run_main_iter() + + runtimes = [] + fetches = [] + + self._status = IN_RUN + with profile_context(self.name, use_gpu, profiler): + for i in range(nvprof_end_step + 1): + if i == nvprof_start_step: + core.nvprof_start() + core.nvprof_enable_record_event() + core.nvprof_nvtx_push(str(i)) + if i == nvprof_end_step: + paddle.fluid._cuda_synchronize(paddle.fluid.CUDAPlace(0)) + core.nvprof_nvtx_pop() + core.nvprof_stop() + sys.exit() + if i > nvprof_start_step and i < nvprof_end_step: + core.nvprof_nvtx_pop() + core.nvprof_nvtx_push(str(i)) + begin = time.time() + outputs = _run_main_iter() + runtimes.append(time.time() - begin) + + self._status = AFTER_RUN + stats = { + "framework": "paddle", + "version": paddle.__version__, + "name": self.name, + "device": "GPU" if 
use_gpu else "CPU", + "backward": self._backward, + "total": runtimes + } + + flop, byte = self.compute_flop_and_byte(config) + if flop is not None: + stats["flop"] = flop + if byte is not None: + stats["byte"] = byte + return outputs, stats + + def run(self, config, args, feeder_adapter=None): + paddle.disable_static() + self.name = config.api_name + + self._reset() + self._feed_spec = feeder.copy_feed_spec(config.feed_spec) + self._need_fetch = args.task == "accuracy" + if feeder_adapter: + self._feed_values = feeder_adapter.to_paddle() + outputs, stats = self.run_impl( + use_gpu=args.use_gpu, + config=config, + repeat=args.repeat, + nvprof_start_step=args.nvprof_start_step, + nvprof_end_step=args.nvprof_end_step, + profiler=args.profiler, + feeder_adapter=feeder_adapter) + return outputs, stats + + def _reset(self): + self.feed_list = None + self.fetch_list = None + self._feed_spec = None + self._feed_values = None + self._feed_list = [] + self._backward = False + self._status = BEFORE_RUN + self._feed_dict = {} + self._layers_function = None + self._ones_like_targets = [] diff --git a/api/dynamic_scheduling_tests_v2/__init__.py b/api/dynamic_scheduling_tests_v2/__init__.py new file mode 100644 index 0000000000..6f0ea85344 --- /dev/null +++ b/api/dynamic_scheduling_tests_v2/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/api/dynamic_scheduling_tests_v2/abs.py b/api/dynamic_scheduling_tests_v2/abs.py new file mode 100644 index 0000000000..58a76ee056 --- /dev/null +++ b/api/dynamic_scheduling_tests_v2/abs.py @@ -0,0 +1,39 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from common_import import * + + +class AbsConfig(APIConfig): + def __init__(self): + super(AbsConfig, self).__init__("abs") + self.feed_spec = {"range": [-1, 1]} + # abs belongs to activation op series which only has one parameter + # thus abs can reuse activation.json. 
+ self.alias_name = "activation" + + +class PaddleAbs(PaddleDynamicSchedulingBenchmarkBase): + def build_graph(self, config): + x = self.variable(name="x", shape=config.x_shape, dtype=config.x_dtype) + result = paddle.abs(x=x) + + self.feed_list = [x] + self.fetch_list = [result] + if config.backward: + self.append_gradients(result, [x]) + + +if __name__ == '__main__': + test_main(pd_dy_obj=PaddleAbs(), config=AbsConfig()) diff --git a/api/dynamic_scheduling_tests_v2/run.sh b/api/dynamic_scheduling_tests_v2/run.sh new file mode 100644 index 0000000000..076f290539 --- /dev/null +++ b/api/dynamic_scheduling_tests_v2/run.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +[ -z "$(set | grep '^CUDA_VISIBLE_DEVICES=')" ] && export CUDA_VISIBLE_DEVICES="0" # Set to "" if testing CPU + +NVCC=`which nvcc` +if [ ${NVCC} != "" ]; then + NVCC_VERSION=`nvcc --version | tail -n 1 | grep "[0-9][0-9]*\.[0-9]" -o | uniq` + export LD_LIBRARY_PATH=/usr/local/cuda-${NVCC_VERSION}/extras/CUPTI/lib64:${LD_LIBRARY_PATH} +fi + +OP_BENCHMARK_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )" +export PYTHONPATH=${OP_BENCHMARK_ROOT}:${PYTHONPATH} + +name=${1:-"abs"} +config_id=${2:-"0"} +task=${3:-"accuracy"} # "accuracy" or "speed" + +testing_mode="dynamic" # "static" or "dynamic" +framework="pytorch" # "paddle" or "tensorflow" or "pytorch" +filename="${OP_BENCHMARK_ROOT}/tests_v2/configs/${name}.json" +if [ -z "$CUDA_VISIBLE_DEVICES" ]; then + use_gpu=False +else + use_gpu=True + #export FLAGS_cudnn_exhaustive_search=true +fi + +run_args="--task ${task} \ + --framework ${framework} \ + --testing_mode ${testing_mode} \ + --json_file ${filename} \ + --config_id ${config_id} \ + --profiler none \ + --backward False \ + --use_gpu ${use_gpu} \ + --repeat 1 \ + --nvprof_start_step 10 \ + --nvprof_end_step 100 \ + --allow_adaptive_repeat False \ + --log_level 0" + +if [ $# -ge 4 ]; then + api_name=${4} + run_args="${run_args} \ + --api_name ${api_name}" +fi + +python -m common.launch ${OP_BENCHMARK_ROOT}/dynamic_tests_v2/${name}.py \ + ${run_args} From 1b216b8ef67ecfd7a7d57fef97aa2d2830169ce9 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Thu, 6 Jan 2022 12:25:07 +0000 Subject: [PATCH 2/3] fix bug --- api/common/launch.py | 53 +++++++---- api/common/main.py | 6 +- .../common_import.py | 91 +++++++++++++++++++ api/dynamic_scheduling_tests_v2/run.sh | 8 +- 4 files changed, 134 insertions(+), 24 deletions(-) create mode 100644 api/dynamic_scheduling_tests_v2/common_import.py diff --git a/api/common/launch.py b/api/common/launch.py index 019638d85b..db7ef4f800 100644 --- a/api/common/launch.py +++ b/api/common/launch.py @@ -172,33 +172,35 @@ class NsightRunnerForDynamicScheduling(object): def run(self, cmd, op_type, nvprof_start_step, nvprof_end_step, backward): stdout, exit_code = self._nsight_for_dynamic_scheduling(cmd) if exit_code == 0: - parse_status, gpu_time = self._parse_logs( + parse_status, scheduling_time_dict = self._parse_logs( stdout.split("\n"), op_type, nvprof_start_step, nvprof_end_step, backward) if parse_status: - return gpu_time + return scheduling_time_dict print("Running Error:\n {}".format(stdout)) - return 0.0 + return {} def _nsight_for_dynamic_scheduling(self, cmd): return system.run_command( "nsys profile -t cuda,nvtx --stats true -o tmp.qdrep --force-overwrite true {}". 
format(cmd)) - def _to_float(s): + def _to_float(self, s): return float(s.replace(',', '')) - def _calculate_avg_time(l): - total_time = _to_float(l[1]) - max_time = _to_float(l[5]) - calls = _to_float(l[2]) - 1 + def _calculate_avg_time(self, l): + total_time = self._to_float(l[1]) + max_time = self._to_float(l[5]) + calls = self._to_float(l[2]) - 1 return (total_time - max_time) / calls def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, backward): + print("yoki", logs) flag_nvtx_time = False total_step_time = 0.0 step_count = 0 + parse_status = False # 0: imperative (imperative_avg_time) # 1: op_type (fwd_trace_op_avg_time) @@ -227,11 +229,11 @@ def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, nvtx_range_type) > nvprof_start_step and int( nvtx_range_type) < nvprof_end_step: step_count += 1 - step_time = _to_float(infos[1]) + step_time = self._to_float(infos[1]) total_step_time += step_time if nvtx_range_type in _scheduling_list: - avg_time = _calculate_avg_time(infos) + avg_time = self._calculate_avg_time(infos) nvtx_meta_data_dict[nvtx_range_type] = avg_time # print(nvtx_range_type + ' time: ', avg_time) @@ -241,17 +243,23 @@ def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, nvtx_meta_data_dict['step'] = total_step_time / step_count # print("num_step: ", step_count, " step_avg_time: ", total_step_time / step_count) - scheduling_time_dict['step_avg_time'] = nvtx_meta_data_dict['step'] + scheduling_time_dict['step_avg_time'] = nvtx_meta_data_dict[ + 'step'] if 'step' in nvtx_meta_data_dict else None scheduling_time_dict['imperative_avg_time'] = nvtx_meta_data_dict[ - _scheduling_list[0]] + _scheduling_list[0]] if _scheduling_list[ + 0] in nvtx_meta_data_dict else None scheduling_time_dict['fwd_trace_op_avg_time'] = nvtx_meta_data_dict[ - _scheduling_list[1]] + _scheduling_list[1]] if _scheduling_list[ + 1] in nvtx_meta_data_dict else None scheduling_time_dict['fwd_op_compute_avg_time'] = nvtx_meta_data_dict[ - _scheduling_list[2]] + _scheduling_list[2]] if _scheduling_list[ + 2] in nvtx_meta_data_dict else None scheduling_time_dict['bwd_trace_op_avg_time'] = nvtx_meta_data_dict[ - _scheduling_list[3]] + _scheduling_list[3]] if _scheduling_list[ + 3] in nvtx_meta_data_dict else None scheduling_time_dict['bwd_op_compute_avg_time'] = nvtx_meta_data_dict[ - _scheduling_list[4]] + _scheduling_list[4]] if _scheduling_list[ + 4] in nvtx_meta_data_dict else None if scheduling_time_dict['step_avg_time'] and scheduling_time_dict[ 'imperative_avg_time']: if not backward: @@ -287,8 +295,10 @@ def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, 'bwd_trace_op_avg_time'] - scheduling_time_dict[ 'bwd_op_compute_avg_time'] + parse_status = True + print(scheduling_time_dict) - return scheduling_time_dict + return parse_status, scheduling_time_dict def launch(benchmark_script, @@ -381,6 +391,15 @@ def _args_list_to_dict(arg_list): system.check_commit() + if use_gpu and task == "scheduling" and profiler == "none": + total_gpu_time = launch( + args.benchmark_script, + args.benchmark_script_args, + with_nvprof=False, + with_dynamic_scheduling=True) + args.benchmark_script_args.append(" --gpu_time ") + args.benchmark_script_args.append(str(total_gpu_time)) + if use_gpu and task == "speed" and profiler == "none": total_gpu_time = launch( args.benchmark_script, diff --git a/api/common/main.py b/api/common/main.py index a988b1ba75..45d3cab6cd 100644 --- a/api/common/main.py +++ b/api/common/main.py @@ -45,7 +45,7 @@ def 
parse_args(): '--task', type=str, default="speed", - help='Specify the task: [speed|accuracy]') + help='Specify the task: [speed|accuracy|scheduling]') parser.add_argument( '--testing_mode', type=str, @@ -120,8 +120,8 @@ def parse_args(): parser.add_argument( '--log_level', type=int, default=0, help='level of logging') args = parser.parse_args() - if args.task not in ["speed", "accuracy"]: - raise ValueError("task should be speed, accuracy") + if args.task not in ["speed", "accuracy", "scheduling"]: + raise ValueError("task should be speed, accuracy, scheduling") if args.framework not in [ "paddle", "tensorflow", "tf", "pytorch", "torch", "both" ]: diff --git a/api/dynamic_scheduling_tests_v2/common_import.py b/api/dynamic_scheduling_tests_v2/common_import.py new file mode 100644 index 0000000000..cc1f30511d --- /dev/null +++ b/api/dynamic_scheduling_tests_v2/common_import.py @@ -0,0 +1,91 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, sys +import numpy as np + +try: + import paddle +except ImportError: + sys.stderr.write("Cannot import paddle, maybe paddle is not installed.\n") + +try: + import torch +except ImportError: + sys.stderr.write("Cannot import pytorch, maybe paddle is not installed.\n") + +package_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(package_path) + +from common.paddle_dynamic_scheduling_api_benchmark import PaddleDynamicSchedulingBenchmarkBase +from common.paddle_dynamic_api_benchmark import PaddleDynamicAPIBenchmarkBase +from common.pytorch_api_benchmark import PytorchAPIBenchmarkBase +from common.api_param import APIConfig +from common.main import test_main, test_main_without_json + + +def use_gpu(): + return os.environ.get("CUDA_VISIBLE_DEVICES", None) != "" + + +def unsqueeze_short(short, long): + """ + Unsqueeze the short shape to the same length of the long's. + For example: short is [16, 2048] and long is [16, 2048, 7, 7], + it will return [16, 2048, 1, 1]. + """ + # Extend short with 0s. 
+ short_extend_zeros = np.zeros([len(long)], dtype=np.int32).tolist() + start = 0 + for value in short: + for i in range(start, len(long)): + if long[i] == value: + short_extend_zeros[i] = value + start = i + break + # Remove the 0s on the front and change 0s on the middle to 1s, [0, M, 0, N] -> [M, 1, N] + short_extend = [] + first_nonzero_idx = -1 + for i in range(len(short_extend_zeros)): + if first_nonzero_idx == -1 and short_extend_zeros[i] != 0: + first_nonzero_idx = i + if first_nonzero_idx > -1: + if short_extend_zeros[i] == 0: + short_extend.append(1) + else: + short_extend.append(short_extend_zeros[i]) + return short_extend + + +def numel(shape): + assert isinstance( + shape, list), "Expect shape to be a list, but recieved {}".format( + type(shape)) + return np.prod(np.array(shape)) + + +def sizeof(dtype): + assert isinstance( + dtype, str), "Expect dtype to be a string, but recieved {}".format( + type(dtype)) + if dtype in ["float64", "double", "int64", "long"]: + return 8 + elif dtype in ["float32", "float", "int32", "int"]: + return 4 + elif dtype in ["float16", "bfloat16"]: + return 2 + elif dtype in ["bool"]: + return 1 + else: + raise ValueError("{} is not supported.".format(dtype)) diff --git a/api/dynamic_scheduling_tests_v2/run.sh b/api/dynamic_scheduling_tests_v2/run.sh index 076f290539..c2f903d97c 100644 --- a/api/dynamic_scheduling_tests_v2/run.sh +++ b/api/dynamic_scheduling_tests_v2/run.sh @@ -13,10 +13,10 @@ export PYTHONPATH=${OP_BENCHMARK_ROOT}:${PYTHONPATH} name=${1:-"abs"} config_id=${2:-"0"} -task=${3:-"accuracy"} # "accuracy" or "speed" +task=${3:-"scheduling"} # "accuracy" or "speed" or "scheduling" testing_mode="dynamic" # "static" or "dynamic" -framework="pytorch" # "paddle" or "tensorflow" or "pytorch" +framework="paddle" # "paddle" or "tensorflow" or "pytorch" filename="${OP_BENCHMARK_ROOT}/tests_v2/configs/${name}.json" if [ -z "$CUDA_VISIBLE_DEVICES" ]; then use_gpu=False @@ -45,5 +45,5 @@ if [ $# -ge 4 ]; then --api_name ${api_name}" fi -python -m common.launch ${OP_BENCHMARK_ROOT}/dynamic_tests_v2/${name}.py \ - ${run_args} +python -m common.launch ${OP_BENCHMARK_ROOT}/dynamic_scheduling_tests_v2/${name}.py \ + ${run_args} true From 27b177585e10c22bb86c10a36e243377bb7a4625 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Mon, 24 Jan 2022 09:58:56 +0000 Subject: [PATCH 3/3] fix bug --- api/common/launch.py | 2 +- api/dynamic_scheduling_tests_v2/run.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/common/launch.py b/api/common/launch.py index db7ef4f800..7582ee3977 100644 --- a/api/common/launch.py +++ b/api/common/launch.py @@ -196,7 +196,7 @@ def _calculate_avg_time(self, l): def _parse_logs(self, logs, op_type, nvprof_start_step, nvprof_end_step, backward): - print("yoki", logs) + # print("yoki", logs) flag_nvtx_time = False total_step_time = 0.0 step_count = 0 diff --git a/api/dynamic_scheduling_tests_v2/run.sh b/api/dynamic_scheduling_tests_v2/run.sh index c2f903d97c..8a2859022a 100644 --- a/api/dynamic_scheduling_tests_v2/run.sh +++ b/api/dynamic_scheduling_tests_v2/run.sh @@ -46,4 +46,4 @@ if [ $# -ge 4 ]; then fi python -m common.launch ${OP_BENCHMARK_ROOT}/dynamic_scheduling_tests_v2/${name}.py \ - ${run_args} true + ${run_args}
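
The scheduling measurement added by this series works in three stages: the benchmark loop in paddle_dynamic_scheduling_api_benchmark.py marks every profiled step with an NVTX range, common/launch.py re-runs the benchmark script under nsys and averages the resulting NVTX statistics, and _parse_logs subtracts those averages into per-layer scheduling overheads. The sketches below walk through the three stages with hypothetical numbers; they are illustrations only, not code taken from the patches. First, the step-marking loop: a trimmed sketch of run_impl that assumes a CUDA build of Paddle exposing the nvprof_* helpers on paddle.fluid.core (the same internals the patch calls) and a run_one_step() stand-in for the real _run_main_iter.

    import sys

    import paddle
    from paddle.fluid import core


    def profile_steps(run_one_step, nvprof_start_step=10, nvprof_end_step=100):
        # Each profiled step is wrapped in an NVTX push/pop named after the step
        # index, so its wall time shows up as a digit-named row in the nsys
        # "NVTX Push-Pop Range Statistics" table.
        for i in range(nvprof_end_step + 1):
            if i == nvprof_start_step:
                core.nvprof_start()                # start CUDA profiling
                core.nvprof_enable_record_event()  # let Paddle emit per-op NVTX ranges
                core.nvprof_nvtx_push(str(i))      # open the range for this step
            if i == nvprof_end_step:
                paddle.fluid._cuda_synchronize(paddle.fluid.CUDAPlace(0))
                core.nvprof_nvtx_pop()
                core.nvprof_stop()
                sys.exit()                         # timing is captured; stop the script
            if nvprof_start_step < i < nvprof_end_step:
                core.nvprof_nvtx_pop()             # close the previous step's range
                core.nvprof_nvtx_push(str(i))      # open this step's range
            run_one_step()

Besides the digit-named step ranges, the "imperative", "<op_type>", "<op_type> compute", "<op_type>_grad" and "<op_type>_grad compute" ranges collected in _scheduling_list appear to be emitted by Paddle itself once record events are enabled.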
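
Second, the per-row averaging of the nsys report. _parse_logs starts reading only after the "NVTX Push-Pop Range Statistics:" header and treats the slowest call of each range as warm-up, so each average drops one call and the per-row maximum. Below is a minimal sketch of that computation; the row is made up, and the column positions (total time in column 1, call count in column 2, maximum in column 5, range name last) are the ones _to_float and _calculate_avg_time rely on.

    def to_float(s):
        # nsys prints thousands separators, e.g. "1,234.5"
        return float(s.replace(",", ""))


    def average_range_time(columns):
        # Mirror of _calculate_avg_time: drop one call (assumed to be the slow
        # warm-up, i.e. the per-row maximum) and average the remaining calls.
        total_time = to_float(columns[1])
        calls = to_float(columns[2]) - 1
        max_time = to_float(columns[5])
        return (total_time - max_time) / calls


    # Hypothetical row: percentage, total, calls, average, minimum, maximum, name.
    row = ["95.0", "12,000.0", "91", "131.9", "90.0", "1,200.0", "abs"]
    print(average_range_time(row))  # (12000 - 1200) / 90 = 120.0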
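
Third, the decomposition into scheduling overheads: the gap between a whole step and the "imperative" range is Python-side overhead, the gap between "imperative" and the traced op is dygraph tracer overhead, and the gap between the traced op and its "compute" range is per-op dispatch overhead (likewise for the _grad ranges when backward is enabled). A sketch of that arithmetic with hypothetical values; the real code in _parse_logs additionally checks that every range was actually found before subtracting.

    def decompose(t, backward=False):
        d = {}
        if backward:
            d["python_call_time"] = (t["step_avg_time"] - t["imperative_avg_time"]
                                     - t["bwd_trace_op_avg_time"])
        else:
            d["python_call_time"] = t["step_avg_time"] - t["imperative_avg_time"]
        d["imperative_call_time"] = (t["imperative_avg_time"]
                                     - t["fwd_trace_op_avg_time"])
        d["fwd_trace_op_call_time"] = (t["fwd_trace_op_avg_time"]
                                       - t["fwd_op_compute_avg_time"])
        if backward:
            d["bwd_trace_op_call_time"] = (t["bwd_trace_op_avg_time"]
                                           - t["bwd_op_compute_avg_time"])
        return d


    # Hypothetical forward-only averages (arbitrary time units).
    sample = {
        "step_avg_time": 300.0,            # digit-named per-step range
        "imperative_avg_time": 220.0,      # "imperative" range
        "fwd_trace_op_avg_time": 150.0,    # "<op_type>" range
        "fwd_op_compute_avg_time": 100.0,  # "<op_type> compute" range
    }
    print(decompose(sample))
    # {'python_call_time': 80.0, 'imperative_call_time': 70.0,
    #  'fwd_trace_op_call_time': 50.0}

With the run.sh defaults from this series (task "scheduling", framework "paddle", dynamic testing mode), common.launch takes the NsightRunnerForDynamicScheduling path whenever a GPU is visible, so nsys must be installed on the machine running the test.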