
When pruning YOLOv9, the Concat channel indices in the groups built by DG.get_all_groups are incorrect #415

Open
EzcodingSen opened this issue Aug 27, 2024 · 3 comments


@EzcodingSen

My pruner setup:

    example_inputs = torch.randn(1, 3, 640, 640).to(device)
    ignored_layers = []
    unwrapped_parameters = []
    importance = tp.importance.GroupNormImportance(p=2)
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs,
        importance,
        iterative_steps=1,
        pruning_ratio=0.35,
        ignored_layers=ignored_layers,
        unwrapped_parameters=unwrapped_parameters,
    )
After building the pruner, I printed each group and its importance, following the _prune function in MetaPruner:

    for group in pruner.DG.get_all_groups(ignored_layers=pruner.ignored_layers, root_module_types=pruner.root_module_types):
        if pruner._check_pruning_ratio(group):
            group = pruner._downstream_node_as_root_if_attention(group)
            ch_groups = pruner._get_channel_groups(group)
            print(group)
            imp = pruner.estimate_importance(group)
            print(imp)

This raised an error:

      Pruning Group

[0] prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), len(idxs)=256
[1] prune_out_channels on model.9.cv3.1.conv (Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)) => prune_out_channels on model.9.cv3.1.bn (BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)), len(idxs)=256
[2] prune_out_channels on model.9.cv3.1.bn (BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_21(SiluBackward0), len(idxs)=256
[3] prune_out_channels on _ElementWiseOp_21(SiluBackward0) => prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]), len(idxs)=256
[4] prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)), len(idxs)=256

........
........
/opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [60,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [61,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [62,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
/opt/conda/conda-bld/pytorch_1720538643151/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [0,0,0], thread: [63,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.
Traceback (most recent call last):
File "/home/xs/yolov9/tp-prune.py", line 234, in
prune(model,save_path,device)
File "/home/xs/yolov9/tp-prune.py", line 76, in prune
imp = pruner.estimate_importance(group)
File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/algorithms/metapruner.py", line 279, in estimate_importance
return self.importance(group)
File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 269, in __call__
group_imp = self._reduce(group_imp, group_idxs)
File "/root/anaconda3/envs/v9/lib/python3.9/site-packages/torch_pruning/pruner/importance.py", line 149, in _reduce
reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

In short, this is an index-mismatch problem.
I noticed that in this group: [4] prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)), len(idxs)=256
the Concat offsets are incorrect: they should be [0, 256, 512, 768, 1024], but here they are [0, 1024, 2048, 2304, 2560].
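To illustrate the mismatch, here is a minimal, self-contained sketch (my own example, not torch-pruning's actual code) of how a Concat node's cumulative offsets map a branch-local channel index to a global one, and why the wrong offsets overflow cv4's 1024 input channels:

```python
# Minimal sketch (not torch-pruning's real implementation) of how a Concat
# node's cumulative offsets map a branch-local channel index to a global one.
def concat_offsets(branch_channels):
    """Running channel offsets of each branch inside the concatenated tensor."""
    offsets = [0]
    for c in branch_channels:
        offsets.append(offsets[-1] + c)
    return offsets

def to_global(offsets, branch, local_idx):
    """Global channel index of `local_idx` within branch `branch`."""
    return offsets[branch] + local_idx

# model.9 concatenates four 256-channel branches into a 1024-channel tensor.
offsets = concat_offsets([256, 256, 256, 256])
print(offsets)                   # [0, 256, 512, 768, 1024]
print(to_global(offsets, 3, 0))  # 768 -- valid for cv4's 1024 input channels

# With the buggy offsets [0, 1024, 2048, 2304, 2560], the same channel maps to
# 2304, which is out of range for a 1024-channel input and triggers the CUDA
# index-out-of-bounds assert during scatter_add_.
print(to_global([0, 1024, 2048, 2304, 2560], 3, 0))  # 2304
```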

So I also printed the indices inside GroupNormImportance:

    @torch.no_grad()
    def __call__(self, group: Group):
        group_imp = []
        group_idxs = []
        # Iterate over all groups and estimate group importance
        for i, (dep, idxs) in enumerate(group):
            layer = dep.layer
            prune_fn = dep.pruning_fn
            root_idxs = group[i].root_idxs
            if not isinstance(layer, tuple(self.target_types)):
                continue

            print(dep)
            print(layer)
            print(root_idxs)
            print(idxs)
            input()

Output:
prune_out_channels on _ConcatOp_18([0, 1024, 2048, 2304, 2560]) => prune_in_channels on model.9.cv4.conv (Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False))
Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]
[2304, 2305, 2306, 2307, 2308, 2309, 2310, 2311, 2312, 2313, 2314, 2315, 2316, 2317, 2318, 2319, 2320, 2321, 2322, 2323, 2324, 2325, 2326, 2327, 2328, 2329, 2330, 2331, 2332, 2333, 2334, 2335, 2336, 2337, 2338, 2339, 2340, 2341, 2342, 2343, 2344, 2345, 2346, 2347, 2348, 2349, 2350, 2351, 2352, 2353, 2354, 2355, 2356, 2357, 2358, 2359, 2360, 2361, 2362, 2363, 2364, 2365, 2366, 2367, 2368, 2369, 2370, 2371, 2372, 2373, 2374, 2375, 2376, 2377, 2378, 2379, 2380, 2381, 2382, 2383, 2384, 2385, 2386, 2387, 2388, 2389, 2390, 2391, 2392, 2393, 2394, 2395, 2396, 2397, 2398, 2399, 2400, 2401, 2402, 2403, 2404, 2405, 2406, 2407, 2408, 2409, 2410, 2411, 2412, 2413, 2414, 2415, 2416, 2417, 2418, 2419, 2420, 2421, 2422, 2423, 2424, 2425, 2426, 2427, 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438, 2439, 2440, 2441, 2442, 2443, 2444, 2445, 2446, 2447, 2448, 2449, 2450, 2451, 2452, 2453, 2454, 2455, 2456, 2457, 2458, 2459, 2460, 2461, 2462, 2463, 2464, 2465, 2466, 2467, 2468, 2469, 2470, 2471, 2472, 2473, 2474, 2475, 2476, 2477, 2478, 2479, 2480, 2481, 2482, 2483, 2484, 2485, 2486, 2487, 2488, 2489, 2490, 2491, 2492, 2493, 2494, 2495, 2496, 2497, 2498, 2499, 2500, 2501, 2502, 2503, 2504, 2505, 2506, 2507, 2508, 2509, 2510, 2511, 2512, 2513, 2514, 2515, 2516, 2517, 2518, 2519, 2520, 2521, 2522, 2523, 2524, 2525, 2526, 2527, 2528, 2529, 2530, 2531, 2532, 2533, 2534, 2535, 2536, 2537, 2538, 2539, 2540, 2541, 2542, 2543, 2544, 2545, 2546, 2547, 2548, 2549, 2550, 2551, 2552, 2553, 2554, 2555, 2556, 2557, 2558, 2559]
This confirms that the indices are already wrong when the DependencyGraph is built.

How can I fix this?

@EzcodingSen
Author

Follow-up: the dependency graph construction skips the chunk operation, which causes the index problem.

@EzcodingSen
Author

Follow-up: I replaced the chunk operation in the operator with split. A new error appeared: when a split branch is used as input to multiple layers, the offset of the torch_pruning._helpers._SplitIndexMapping object is incorrect.

@EzcodingSen
Author

Workaround:
When a split output feeds multiple layers, any consumer index beyond the available split partitions falls back to the last partition.
Change in torch-pruning/dependency.py, from:

    dep.index_mapping[0] = _helpers._SplitIndexMapping(
        offset=offsets[i: i + 2], reverse=False
    )

to:

    # if i exceeds the number of available offset partitions, use the last one
    if i < num_offsets:
        dep.index_mapping[0] = _helpers._SplitIndexMapping(
            offset=offsets[i: i + 2], reverse=False
        )
    else:
        # out-of-range consumers use the last partition
        dep.index_mapping[0] = _helpers._SplitIndexMapping(
            offset=offsets[-2:], reverse=False
        )
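The fallback logic above can be isolated into a plain-Python sketch (the names `pick_offset` and `offsets` are mine, not torch-pruning's) to show what it returns for an out-of-range consumer index:

```python
# Hypothetical standalone sketch of the workaround's fallback: when the i-th
# consumer of a split output exceeds the number of split partitions, reuse the
# last partition's offset pair.
def pick_offset(offsets, i):
    num_offsets = len(offsets) - 1  # number of split partitions
    if i < num_offsets:
        return offsets[i:i + 2]
    return offsets[-2:]            # out of range -> last partition

offsets = [0, 256, 512]            # a split into two 256-channel parts
print(pick_offset(offsets, 0))     # [0, 256]
print(pick_offset(offsets, 1))     # [256, 512]
print(pick_offset(offsets, 5))     # [256, 512]  (falls back to the last part)
```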

After replacing chunk with split in the RepNCSPELAN4 and ADown operators, skipping cv1 and cv4 of RepNCSPELAN4, and skipping the layers associated with CBFuse (still to be optimized), all pruning algorithms supported by torch-pruning run successfully.
My ability is limited; this was a matter of necessity.
If there is a better approach, please advise.
