
Fuse Split-Reduce with MLIR #3319

Merged: 195 commits merged into develop on Aug 21, 2024

Conversation

@umangyadav (Member) commented Jul 30, 2024

Part of #3212

Depends on #3097 #3299 and ROCm/rocMLIR#1590

@CharlieL7 (Collaborator) left a comment

We should also have a compiler pass test for the new fusion, right?

@umangyadav (Member Author)

> We should also have a compiler pass test for the new fusion, right?

Yeah, they are a bit tricky to write. Let me add one or two. I have a verify test otherwise.

```cpp
auto mlir_output_with_attrs =
    migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}});
CHECK(encode(s) == encode(mlir_output_with_attrs));
// EXPECT(verify_mlir(m));
```
@umangyadav (Member Author)

verify is failing, so I am disabling it for now. It could be an issue with rocMLIR.

@umangyadav (Member Author)

> We should also have a compiler pass test for the new fusion, right?

Added tests

```cpp
migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 32, 10, 64, 64}}}), b);
auto fused =
    add_mlir(p2,
             "mlir_main:pointwise0_main:split_reduce0",
```
Collaborator:

minor: why do we lose "convolution" in the name of the MLIR instruction?

@umangyadav (Member Author)

Names are constructed from the modules that are fused; convolution or dot would instead appear in the mlir_op[op=...] attribute:

```cpp
operation op = make_op("convolution");
```
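For illustration only, a minimal sketch of how such a fused name could be assembled from the names of the participating modules; the `make_fused_name` helper and the module-name list here are assumptions for this sketch, not the actual MIGraphX implementation:

```cpp
#include <string>
#include <vector>

// Hypothetical sketch: a fused instruction name such as
// "mlir_main:pointwise0_main:split_reduce0" can be built by prefixing
// "mlir" and appending each fused module's name. The operator kind
// (convolution, dot, ...) is carried in the mlir_op's "op" attribute
// rather than in the name, which is why "convolution" disappears from it.
std::string make_fused_name(const std::vector<std::string>& module_names)
{
    std::string name = "mlir";
    for(const auto& mod_name : module_names)
        name += "_" + mod_name;
    return name;
}

// make_fused_name({"main:pointwise0", "main:split_reduce0"})
//   -> "mlir_main:pointwise0_main:split_reduce0"
```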

Comment on lines 1143 to 1151
```cpp
for(const auto i : iterator_for(m))
{
    if(starts_with(i->name(), "@"))
    {
        continue;
    }
    problem_config += " " + i->name();
}
tc.problem = problem_config;
```
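As a standalone illustration of what this loop produces (the instruction names below are examples, not from the PR): builtin instructions whose names begin with "@" (such as @param or @return) are skipped, and the remaining instruction names are concatenated into the problem key.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main()
{
    // Example instruction stream for a hypothetical module
    std::vector<std::string> ins_names = {"@param", "convolution", "add", "@return"};
    std::string problem_config;
    for(const auto& name : ins_names)
    {
        if(name.rfind("@", 0) == 0) // equivalent of starts_with(name, "@")
            continue;
        problem_config += " " + name;
    }
    std::cout << problem_config << "\n"; // prints " convolution add"
}
```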
@umangyadav (Member Author)

I had these changes to experiment with and work around a problem_cache issue. Reverting them.

```cpp
{
    param_map_2[skip_input] = skip_input;
}
return main_mod.fuse(*sub_pm, inputs, &param_map_2).front();
```
@pfultz2 (Collaborator) commented Aug 16, 2024

Actually, fuse is a poor choice here. That's why you need to skip the parameters in the param map. Also, it doesn't insert the instructions at pos. Instead, we should add an insert_inline method that can insert the instructions correctly:

```cpp
std::vector<instruction_ref>
module::insert_inline(instruction_ref ins,
                      const module& m,
                      const std::vector<instruction_ref>& inputs,
                      std::unordered_map<instruction_ref, instruction_ref>* map_ins,
                      module::inserter insert)
{
    std::unordered_map<instruction_ref, instruction_ref> default_map_ins;
    if(map_ins == nullptr)
        map_ins = &default_map_ins;
    // Map m's parameters directly to the given inputs so they are not duplicated
    auto param_map = m.get_ins_param_map(inputs, true);
    map_ins->insert(param_map.begin(), param_map.end());
    return this->insert_instructions(ins, &m, map_ins, std::move(insert));
}
```

Then you can do main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2).front(), and you can skip the skip_input changes.
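A sketch of the resulting call site, assuming the same pos, sub_pm, inputs, and param_map_2 as in the snippet above:

```cpp
// With insert_inline, the parameter-to-input mapping happens inside the
// method, so param_map_2 no longer needs the skip_input entries, and the
// fused instructions are placed at pos instead of at the end of the module.
auto fused_ins = main_mod.insert_inline(pos, *sub_pm, inputs, &param_map_2).front();
```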

@umangyadav (Member Author)
Done. Thanks.

@umangyadav requested a review from pfultz2 on August 18, 2024 at 13:01
@migraphx-bot (Collaborator)

| Test | Batch | Rate new (94e112) | Rate old (05b2ff) | Diff | Compare |
|---|---|---|---|---|---|
| torchvision-resnet50 | 64 | 3,233.91 | 3,238.29 | -0.14% | |
| torchvision-resnet50_fp16 | 64 | 6,887.36 | 6,890.63 | -0.05% | |
| torchvision-densenet121 | 32 | 2,428.79 | 2,427.57 | 0.05% | |
| torchvision-densenet121_fp16 | 32 | 4,081.01 | 4,070.04 | 0.27% | |
| torchvision-inceptionv3 | 32 | 1,633.94 | 1,634.43 | -0.03% | |
| torchvision-inceptionv3_fp16 | 32 | 2,742.24 | 2,737.22 | 0.18% | |
| cadene-inceptionv4 | 16 | 770.98 | 771.30 | -0.04% | |
| cadene-resnext64x4 | 16 | 807.25 | 806.92 | 0.04% | |
| slim-mobilenet | 64 | 7,437.40 | 7,442.09 | -0.06% | |
| slim-nasnetalarge | 64 | 207.38 | 207.44 | -0.03% | |
| slim-resnet50v2 | 64 | 3,340.00 | 3,342.32 | -0.07% | |
| bert-mrpc-onnx | 8 | 1,148.01 | 1,152.95 | -0.43% | |
| bert-mrpc-tf | 1 | 309.91 | 309.74 | 0.06% | |
| pytorch-examples-wlang-gru | 1 | 418.38 | 512.77 | -18.41% | 🔴 |
| pytorch-examples-wlang-lstm | 1 | 388.16 | 387.70 | 0.12% | |
| torchvision-resnet50_1 | 1 | 767.53 | 804.05 | -4.54% | 🔴 |
| cadene-dpn92_1 | 1 | 431.92 | 395.66 | 9.16% | 🔆 |
| cadene-resnext101_1 | 1 | 379.02 | 374.54 | 1.20% | |
| onnx-taau-downsample | 1 | 343.93 | 344.49 | -0.16% | |
| dlrm-criteoterabyte | 1 | 35.08 | 35.05 | 0.07% | |
| dlrm-criteoterabyte_fp16 | 1 | 57.25 | 57.31 | -0.11% | |
| agentmodel | 1 | 8,174.68 | 8,142.79 | 0.39% | |
| unet_fp16 | 2 | 57.77 | 57.75 | 0.04% | |
| resnet50v1_fp16 | 1 | 933.75 | 929.86 | 0.42% | |
| resnet50v1_int8 | 1 | 945.60 | 922.95 | 2.45% | |
| bert_base_cased_fp16 | 64 | 1,141.42 | 1,142.41 | -0.09% | |
| bert_large_uncased_fp16 | 32 | 351.78 | 351.90 | -0.03% | |
| bert_large_fp16 | 1 | 211.18 | 208.73 | 1.18% | |
| distilgpt2_fp16 | 16 | 2,153.21 | 2,155.12 | -0.09% | |
| yolov5s | 1 | 503.72 | 503.82 | -0.02% | |
| tinyllama | 1 | 43.34 | 43.36 | -0.04% | |
| vicuna-fastchat | 1 | 177.12 | 175.40 | 0.98% | |
| whisper-tiny-encoder | 1 | 409.80 | 410.24 | -0.11% | |
| whisper-tiny-decoder | 1 | 427.53 | 426.66 | 0.20% | |

This build is not recommended to merge 🔴

@migraphx-bot (Collaborator)


✅ bert-mrpc-onnx: PASSED: MIGraphX meets tolerance
✅ bert-mrpc-tf: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-gru: PASSED: MIGraphX meets tolerance
✅ pytorch-examples-wlang-lstm: PASSED: MIGraphX meets tolerance
✅ torchvision-resnet50_1: PASSED: MIGraphX meets tolerance
✅ cadene-dpn92_1: PASSED: MIGraphX meets tolerance
✅ cadene-resnext101_1: PASSED: MIGraphX meets tolerance
✅ dlrm-criteoterabyte: PASSED: MIGraphX meets tolerance
✅ agentmodel: PASSED: MIGraphX meets tolerance
✅ unet: PASSED: MIGraphX meets tolerance
✅ resnet50v1: PASSED: MIGraphX meets tolerance
✅ bert_base_cased_fp16: PASSED: MIGraphX meets tolerance
🔴 bert_large_uncased_fp16: FAILED: MIGraphX is not within tolerance - check verbose output
✅ bert_large: PASSED: MIGraphX meets tolerance
✅ yolov5s: PASSED: MIGraphX meets tolerance
✅ tinyllama: PASSED: MIGraphX meets tolerance
✅ vicuna-fastchat: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-encoder: PASSED: MIGraphX meets tolerance
✅ whisper-tiny-decoder: PASSED: MIGraphX meets tolerance
✅ distilgpt2_fp16: PASSED: MIGraphX meets tolerance

```diff
@@ -50,6 +50,8 @@ struct MIGRAPHX_GPU_EXPORT mlir_code_object
     std::vector<value> prefill_values = {};
 };
 
+MIGRAPHX_GPU_EXPORT bool is_reduce(const instruction& ins);
```
Collaborator:

I don't see this used outside of mlir.cpp. I think it can be removed from the header.

@umangyadav (Member Author) commented Aug 19, 2024

It is being used in both fuse_mlir.cpp and mlir.cpp.

@causten merged commit 7ab413f into develop on Aug 21, 2024; 45 of 48 checks passed.
@causten deleted the mlir-split-reduce branch on August 21, 2024 at 14:19.
Development

Successfully merging this pull request may close these issues.

Fuse reductions with MLIR with multi-outputs