diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000000..5680a960831 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @causten diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index d8429d843e7..1b4fa586379 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -280,6 +281,78 @@ struct find_concat_multibroadcasts } }; +struct find_concat_slice +{ + auto matcher() const + { + return match::name("concat")(match::any_of[match::outputs()](match::name("slice"))); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto inputs = ins->inputs(); + auto outs = ins->outputs(); + std::vector slice_ins; + migraphx::transform_if( + outs.begin(), + outs.end(), + std::back_inserter(slice_ins), + [&](const auto& oins) { return oins->name() == "slice"; }, + [&](const auto& oins) { return oins; }); + int concat_axis = any_cast(ins->get_operator()).axis; + // prune slice candidates + std::vector slice_candidates; + for(const auto& sins : range(slice_ins.begin(), slice_ins.end())) + { + auto sop = any_cast(sins->get_operator()); + // slices with only one axis is allowed, because concat happens only one axis + if(sop.axes.size() != 1 or sop.axes.front() != concat_axis) + { + continue; + } + slice_candidates.push_back(sins); + } + if(slice_candidates.empty()) + { + return; + } + std::vector prefix_scan = {0}; + std::transform( + inputs.begin(), inputs.end(), std::back_inserter(prefix_scan), [&](const auto& i) { + return prefix_scan.back() + i->get_shape().lens()[concat_axis]; + }); + for(const auto& sins : slice_candidates) + { + auto sop = any_cast(sins->get_operator()); + size_t slice_start = sop.starts.front(); + size_t slice_len = sop.ends.front() - slice_start; + auto fii = std::find_if(prefix_scan.begin(), prefix_scan.end(), [&](const auto& j) { + return j == slice_start; + }); + if(fii == prefix_scan.end()) + { + continue; + } + // slice_len == 0 + else if(fii == prefix_scan.end() - 1) + { + assert(slice_len == 0 or slice_start >= prefix_scan.back()); + continue; + } + else + { + size_t idx = std::distance(prefix_scan.begin(), fii); + if(inputs[idx]->get_shape().lens()[concat_axis] == slice_len) + { + assert((prefix_scan[idx + 1] - prefix_scan[idx]) == slice_len); + m.replace_instruction(sins, inputs[idx]); + } + } + } + } +}; + struct find_concat_transpose { auto matcher() const @@ -806,6 +879,55 @@ struct find_transpose_slice } }; +struct find_reshape_reshape_dot +{ + auto matcher() const + { + return match::name("dot")(match::used_once(), + match::args(match::name("reshape").bind("inp_rsp1"), + match::name("reshape").bind("inp_rsp2"))); + } + + // Gemm axis should not be altered by the reshape + auto is_valid_reshape(instruction_ref in, instruction_ref rsp) const + { + auto in_lens = in->get_shape().lens(); + auto rsp_lens = rsp->get_shape().lens(); + + return std::equal(rsp_lens.end() - 2, rsp_lens.end(), in_lens.end() - 2, in_lens.end()); + } + + // Batch dims should match for both inputs + auto is_valid_inputs(instruction_ref in1, instruction_ref in2) const + { + auto in1_lens = in1->get_shape().lens(); + auto in2_lens = in2->get_shape().lens(); + + return ( + in1_lens.size() == in2_lens.size() and + std::equal(in1_lens.begin(), in1_lens.end() - 2, in2_lens.begin(), in2_lens.end() - 2)); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto dot = r.result; + auto inp_rsp1 = r.instructions["inp_rsp1"]; + auto inp_rsp2 = r.instructions["inp_rsp2"]; + + auto dot_lens = dot->get_shape().lens(); + + auto inp1 = inp_rsp1->inputs().front(); + auto inp2 = inp_rsp2->inputs().front(); + + if(not(is_valid_reshape(inp1, inp_rsp1) and is_valid_reshape(inp2, inp_rsp2) and + is_valid_inputs(inp1, inp2))) + return; + + auto new_dot = m.insert_instruction(dot, dot->get_operator(), inp1, inp2); + m.replace_instruction(dot, make_op("reshape", {{"dims", dot_lens}}), new_dot); + } +}; + void simplify_reshapes::apply(module& m) const { for(int i = 0; i < depth; i++) @@ -817,6 +939,7 @@ void simplify_reshapes::apply(module& m) const find_reshaper{}, find_reshape_cont{}, find_transpose{}, + find_concat_slice{}, find_concat_transpose{}, find_concat_multibroadcasts{}, find_nested_slice{}, @@ -824,7 +947,8 @@ void simplify_reshapes::apply(module& m) const find_transpose_slice{}, find_broadcast_transpose{}, find_slice_transpose{}, - find_transpose_contiguous_reshaper_unary{}); + find_transpose_contiguous_reshaper_unary{}, + find_reshape_reshape_dot{}); dead_code_elimination{}.apply(m); } } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 7e8bb74619d..38c050d925f 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -556,6 +556,299 @@ TEST_CASE(nested_squeeze_reshape) EXPECT(m1 == m2); } +TEST_CASE(concat_slice_different_axis_1) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), concat); + auto add = m1.add_instruction(migraphx::make_op("add"), slice1, slice2); + m1.add_return({add}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_different_axis_2) +{ + // two slices, one with same axis but other with different + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto add = m1.add_instruction(migraphx::make_op("add"), x, slice2); + m1.add_return({slice1, add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), concat); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_return({slice1, add}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_in_same_order) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {160}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto add = m1.add_instruction(migraphx::make_op("add"), slice1, slice2); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_return({add}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_in_reverse_order) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {160}}}), concat); + auto add = m1.add_instruction(migraphx::make_op("add"), slice1, slice2); + m1.add_return({add}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto add = m2.add_instruction(migraphx::make_op("add"), y, x); + m2.add_return({add}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_inorder_with_empty_slice) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {160}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto slice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {320}}, {"ends", {360}}}), + concat); + auto add = m1.add_instruction(migraphx::make_op("add"), slice1, slice2); + m1.add_return({add, slice3}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice3 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {320}}, {"ends", {360}}}), + concat); + auto add = m2.add_instruction(migraphx::make_op("add"), x, y); + m2.add_return({add, slice3}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_uneven_len_1) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto z = m1.add_parameter("z", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {100}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {100}}, {"ends", {160}}}), + concat); + auto slice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto slice4 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {320}}, {"ends", {420}}}), + concat); + auto slice5 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {420}}, {"ends", {480}}}), + concat); + auto add1 = m1.add_instruction(migraphx::make_op("add"), slice1, slice4); + auto add2 = m1.add_instruction(migraphx::make_op("add"), slice2, slice5); + auto add3 = m1.add_instruction(migraphx::make_op("add"), slice3, z); + m1.add_return({add1, add2, add3}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto z = m2.add_parameter("z", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + auto slice1 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {100}}}), concat); + auto slice2 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {100}}, {"ends", {160}}}), + concat); + auto slice4 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {320}}, {"ends", {420}}}), + concat); + auto slice5 = m2.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {420}}, {"ends", {480}}}), + concat); + auto add1 = m2.add_instruction(migraphx::make_op("add"), slice1, slice4); + auto add2 = m2.add_instruction(migraphx::make_op("add"), slice2, slice5); + auto add3 = m2.add_instruction(migraphx::make_op("add"), y, z); + m2.add_return({add1, add2, add3}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_uneven_len_2) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {150}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {150}}, {"ends", {300}}}), + concat); + auto slice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {300}}, {"ends", {320}}}), + concat); + auto add = m1.add_instruction(migraphx::make_op("add"), slice1, slice2); + m1.add_return({add, slice3}); + } + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_multiple_slice_use) +{ + // multiple use for slice1 and slice3, single use for slice2 + auto s = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto z = m1.add_parameter("z", s); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {160}}}), concat); + auto slice2 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto slice3 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {320}}, {"ends", {480}}}), + concat); + auto add1 = m1.add_instruction(migraphx::make_op("add"), slice1, z); + auto add2 = m1.add_instruction(migraphx::make_op("add"), slice3, x); + auto sub1 = m1.add_instruction(migraphx::make_op("sub"), slice1, z); + auto sub2 = m1.add_instruction(migraphx::make_op("sub"), slice3, x); + auto add3 = m1.add_instruction(migraphx::make_op("add"), sub1, sub2); + auto sub3 = m1.add_instruction(migraphx::make_op("sub"), add3, slice2); + m1.add_return({add1, add2, sub3}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto z = m2.add_parameter("z", s); + auto add1 = m2.add_instruction(migraphx::make_op("add"), x, z); + auto add2 = m2.add_instruction(migraphx::make_op("add"), z, x); + auto sub1 = m2.add_instruction(migraphx::make_op("sub"), x, z); + auto sub2 = m2.add_instruction(migraphx::make_op("sub"), z, x); + auto add3 = m2.add_instruction(migraphx::make_op("add"), sub1, sub2); + auto sub3 = m2.add_instruction(migraphx::make_op("sub"), add3, y); + m2.add_return({add1, add2, sub3}); + } + EXPECT(m1 == m2); +} + +TEST_CASE(concat_slice_with_multiple_concat_outs) +{ + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 160}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 480}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s1); + auto z = m1.add_parameter("z", s1); + auto w = m1.add_parameter("w", s2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + auto slice1 = m1.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {160}}, {"ends", {320}}}), + concat); + auto add1 = m1.add_instruction(migraphx::make_op("add"), concat, w); + auto add2 = m1.add_instruction(migraphx::make_op("add"), slice1, z); + m1.add_return({add1, add2}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s1); + auto z = m2.add_parameter("z", s1); + auto w = m2.add_parameter("w", s2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + auto add1 = m2.add_instruction(migraphx::make_op("add"), concat, w); + auto add2 = m2.add_instruction(migraphx::make_op("add"), y, z); + m2.add_return({add1, add2}); + } + EXPECT(m1 == m2); +} + TEST_CASE(concat_multibroadcasts1) { // Broadcasted batch dim, new axis < old axis @@ -1697,4 +1990,54 @@ TEST_CASE(transpose_slice_non_packed_multi_axis) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(reshape_reshape_dot) +{ + migraphx::shape as{migraphx::shape::float_type, {2, 10, 32, 16}}; + migraphx::shape bs{migraphx::shape::float_type, {2, 10, 16, 32}}; + migraphx::module m1; + { + auto a = m1.add_literal(migraphx::generate_literal(as)); + auto b = m1.add_parameter("input", bs); + auto a_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {20, 32, 16}}}), a); + auto b_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {20, 16, 32}}}), b); + + auto dot = m1.add_instruction(migraphx::make_op("dot"), a_rsp, b_rsp); + auto dot_rsp = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 10, 32, 32}}}), dot); + m1.add_return({dot_rsp}); + }; + run_pass(m1); + + migraphx::module m2; + { + auto a = m2.add_literal(migraphx::generate_literal(as)); + auto b = m2.add_parameter("input", bs); + auto dot = m2.add_instruction(migraphx::make_op("dot"), a, b); + m2.add_return({dot}); + }; + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(reshape_reshape_dot_gemm_axis) +{ + migraphx::shape as{migraphx::shape::float_type, {2, 10, 512}}; + migraphx::shape bs{migraphx::shape::float_type, {2, 10, 512}}; + migraphx::module m1; + { + auto a = m1.add_literal(migraphx::generate_literal(as)); + auto b = m1.add_parameter("input", bs); + auto a_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {20, 32, 16}}}), a); + auto b_rsp = m1.add_instruction(migraphx::make_op("reshape", {{"dims", {20, 16, 32}}}), b); + + auto dot = m1.add_instruction(migraphx::make_op("dot"), a_rsp, b_rsp); + auto dot_rsp = + m1.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 10, 1024}}}), dot); + m1.add_return({dot_rsp}); + }; + migraphx::module m2 = m1; + run_pass(m1); + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/tools/format.py b/tools/format.py index df80876119a..f04b59b3609 100644 --- a/tools/format.py +++ b/tools/format.py @@ -1,7 +1,7 @@ ##################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -21,24 +21,34 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. ##################################################################################### -import os, shutil, argparse, subprocess +import os, shlex, shutil, argparse, subprocess CLANG_FORMAT_PATH = '/opt/rocm/llvm/bin' +EXCLUDE_FILES = ['requirements.in'] + def run(cmd, **kwargs): - print(cmd) - subprocess.run(cmd, shell=True, check=True, **kwargs) + if isinstance(cmd, str): + print(cmd) + else: + print(shlex.join(cmd)) + subprocess.run(cmd, shell=isinstance(cmd, str), check=True, **kwargs) def eval(cmd, **kwargs): return subprocess.run(cmd, capture_output=True, - shell=True, + shell=isinstance(cmd, str), check=True, **kwargs).stdout.decode('utf-8').strip() +def is_excluded(f): + base = os.path.basename(f) + return base in EXCLUDE_FILES + + def get_top(): return eval("git rev-parse --show-toplevel") @@ -52,6 +62,12 @@ def get_merge_base(branch): return eval(f"git merge-base {branch} {head}") +def get_files_changed(against, ext=('.py')): + files = eval(f"git diff-index --cached --name-only {against}", + cwd=get_top()).splitlines() + return (f for f in files if f.endswith(ext) and not is_excluded(f)) + + def clang_format(against, apply=False, path=CLANG_FORMAT_PATH): base = get_merge_base(against) clang_format = os.path.join(path, 'clang-format') @@ -62,15 +78,14 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH): if not os.path.exists(git_clang_format): print(f"{git_clang_format} not installed. Skipping format.") return - diff_flag = "" if apply else "--diff" - run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}" - ) - - -def get_files_changed(against, ext=('py')): - files = eval(f"git diff-index --cached --name-only {against}", - cwd=get_top()).splitlines() - return (f for f in files if f.endswith(ext)) + diff_flag = [] if apply else ["--diff"] + files = list( + get_files_changed(base, + ext=('.c', '.cpp', '.hpp', '.h', '.cl', '.hip', + '.in'))) + run([git_clang_format, '--binary', clang_format] + diff_flag + [base] + + files, + cwd=get_top()) def yapf_format(against, apply=False): @@ -80,7 +95,7 @@ def yapf_format(against, apply=False): diff_flag = "--in-place" if apply else "--diff" files = ' '.join(get_files_changed(against)) if files: - run(f"yapf {diff_flag} -p {files}") + run(f"yapf {diff_flag} -p {files}", cwd=get_top()) else: print("No modified python files to format")