AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)> while loading the pruned model #399

Open
jenyaMi opened this issue Jul 7, 2024 · 5 comments


jenyaMi commented Jul 7, 2024

I am using the YOLOv8 pruning code, more precisely the version updated for newer ultralytics releases: https://github.com/chbw818/yolov8-prune-using-torch-pruning-/tree/main

I am struggling to load the pruned model:

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>

AttributeError Traceback (most recent call last)
in
16 args = parser.parse_args()
17
---> 18 prune(args)
in prune(args)
432 pruning_cfg['name'] = f"step_{i}_finetune"
433 pruning_cfg['batch'] = batch_size # restore batch size
--> 434 model.train_v2(pruning=True, **pruning_cfg)
435
436 # post fine-tuning validation
in train_v2(self, pruning, **kwargs)
337
338 self.trainer.hub_session = self.session # attach optional HUB session
--> 339 self.trainer.train()
340 # Update model and cfg after training
341 if RANK in (-1, 0):
/workspace/data/notebooks/belt_detection/ultralytics/ultralytics/engine/trainer.py in train(self)
202
203 else:
--> 204 self._do_train(world_size)
205
206 def _setup_scheduler(self):
/workspace/data/notebooks/belt_detection/ultralytics/ultralytics/engine/trainer.py in _do_train(self, world_size)
467 f"{(time.time() - self.train_time_start) / 3600:.3f} hours."
468 )
--> 469 self.final_eval()
470 if self.args.plots:
471 self.plot_metrics()
in final_eval_v2(self)
272 for f in self.last, self.best:
273 if f.exists():
--> 274 strip_optimizer_v2(f) # strip optimizers
275 if f is self.best:
276 LOGGER.info(f'\nValidating {f}...')
in strip_optimizer_v2(f, s)
284 Disabled half precision saving. originated from ultralytics/yolo/utils/torch_utils.py
285 """
--> 286 x = torch.load(f, map_location=torch.device('cpu'))
287 args = {**DEFAULT_CFG_DICT, **x['train_args']} # combine model args with default args, preferring model args
288 if x.get('ema'):
/opt/conda/lib/python3.8/site-packages/torch/serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
605 opened_file.seek(orig_position)
606 return torch.jit.load(opened_file)
--> 607 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
608 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
/opt/conda/lib/python3.8/site-packages/torch/serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
880 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
881 unpickler.persistent_load = persistent_load
--> 882 result = unpickler.load()
883
884 torch._utils._validate_loaded_sparse_tensors()
/opt/conda/lib/python3.8/site-packages/torch/serialization.py in find_class(self, mod_name, name)
873 def find_class(self, mod_name, name):
874 mod_name = load_module_mapping.get(mod_name, mod_name)
--> 875 return super().find_class(mod_name, name)
876
877 # Load the data (which may in turn use persistent_load to load tensors)
AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>

Does anyone know what the problem could be? I have seen many people run this code successfully. I am using the latest version of ultralytics, 8.2.50.
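
For context, here is a minimal standard-library sketch of the general mechanism behind "Can't get attribute" errors during unpickling. pickle stores a class by reference (module name plus class name), so loading fails when that name cannot be resolved in the loading process. The module name demo_prune_script and the class Block are hypothetical stand-ins, and this is only an illustration of how pickle looks names up, not a diagnosis of the exact failure in the traceback above.

```python
import pickle
import sys
import types


class Block:
    """Hypothetical stand-in for a custom layer class defined in a training script."""

    def __init__(self, channels):
        self.channels = channels


# Register the class under a throwaway module so the demo does not depend on how it is run.
demo = types.ModuleType("demo_prune_script")
Block.__module__ = "demo_prune_script"
Block.__qualname__ = "Block"
demo.Block = Block
sys.modules["demo_prune_script"] = demo

# The payload records the name "demo_prune_script.Block", not the class body.
payload = pickle.dumps(Block(64))

# Loading works while the class is still reachable under that name.
restored = pickle.loads(payload)

# Simulate a loading process in which the class is no longer defined there.
del demo.Block
try:
    pickle.loads(payload)
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

The same lookup happens inside torch.load, which delegates to whichever pickle module it is given.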

@Ramzes30765

Same for me. If you find a solution, please tag me.


jenyaMi commented Jul 11, 2024

@Ramzes30765 hey
Explicitly passing pickle_module=dill to torch.load, like this: torch.load('your_model.pt', pickle_module=dill), helped me (this needs import dill). I added it in the strip_optimizer_v2 function and in the torch_safe_load function in the ultralytics/nn/tasks.py file.
Let me know if you find anything else.
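
A minimal sketch of that workaround, assuming torch and dill are installed. The checkpoint contents and the in-memory buffer are illustrative only. Passing pickle_module=dill to torch.save as well is an assumption made here for symmetry; the thread only confirms the torch.load side.

```python
import io

import dill
import torch

# Illustrative checkpoint; a real one would hold the model, EMA weights, train args, etc.
checkpoint = {"epoch": 3, "weights": torch.ones(2, 2)}

# Serialize with dill as the pickle backend (assumption: mirrors the load side).
buffer = io.BytesIO()
torch.save(checkpoint, buffer, pickle_module=dill)

# Deserialize with the same backend, mapping tensors to the CPU.
buffer.seek(0)
restored = torch.load(buffer, map_location="cpu", pickle_module=dill)

print(restored["epoch"], restored["weights"].shape)
```

With a file on disk, the same keyword goes into the existing call, for example torch.load(f, map_location=torch.device("cpu"), pickle_module=dill) inside strip_optimizer_v2.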


Ramzes30765 commented Jul 11, 2024

@jenyaMi Hello! Thank you for the information.

pickle_module=dill
Should I add this parameter to torch.save() too?

@Ramzes30765

I still get errors:

Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to runs/detect/step_1_finetune
Starting training for 10 epochs...
Closing dataloader mosaic

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size
  0%|          | 0/8 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 396, in <module>
    prune(args)
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 357, in prune
    model.train_v2(pruning=True, **pruning_cfg)
  File "/home/dalekseenko/models_optimizing/ultralytics/yolov8_pruning.py", line 266, in train_v2
    self.trainer.train()
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/engine/trainer.py", line 204, in train
    self._do_train(world_size)
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/engine/trainer.py", line 381, in _do_train
    self.loss, self.loss_items = self.model(batch)
                                 ^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/nn/tasks.py", line 101, in forward
    return self.loss(x, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/nn/tasks.py", line 283, in loss
    return self.criterion(preds, batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/utils/loss.py", line 229, in __call__
    pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/ultralytics/ultralytics/utils/loss.py", line 201, in bbox_decode
    pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat in method wrapper_CUDA_addmv_)
Exception in thread Thread-42 (_pin_memory_loop):
Traceback (most recent call last):
  File "/usr/lib/python3.12/threading.py", line 1073, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.12/threading.py", line 1010, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 54, in _pin_memory_loop
    do_one_step()
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 31, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/queues.py", line 122, in get
    return _ForkingPickler.loads(res)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dalekseenko/models_optimizing/.venv/lib/python3.12/site-packages/torch/multiprocessing/reductions.py", line 495, in rebuild_storage_fd
    fd = df.detach()
         ^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/resource_sharer.py", line 86, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 525, in Client
    answer_challenge(c, authkey)
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 953, in answer_challenge
    message = connection.recv_bytes(256)         # reject large message
              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 430, in _recv_bytes
    buf = self._recv(4)
          ^^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/connection.py", line 395, in _recv
    chunk = read(handle, remaining)
            ^^^^^^^^^^^^^^^^^^^^^^^
ConnectionResetError: [Errno 104] Connection reset by peer


YMCAlan commented Jul 19, 2024

I encountered the same issue. After the modifications below, the script runs, but I am now facing an "index error". Here is my modified script:

# This code is adapted from Issue [#147](https://github.com/VainF/Torch-Pruning/issues/147), implemented by @Hyunseok-Kim0.
import argparse
import math
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import List, Union

import dill
import numpy as np
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from ultralytics import YOLO, __version__
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.nn.modules import Detect, C2f, Conv, Bottleneck
from ultralytics.nn.tasks import attempt_load_one_weight
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.utils import (
    DEFAULT_CFG_DICT,
    DEFAULT_CFG_KEYS,
    LOGGER,
    RANK,
    checks,
    yaml_load,
)
from ultralytics.utils.torch_utils import initialize_weights, de_parallel
from ultralytics.engine.model import Model
from ultralytics.models.yolo.model import DetectionModel
import torch_pruning as tp


def save_pruning_performance_graph(x, y1, y2, y3):
    """
    Draw performance change graph
    Parameters
    ----------
    x : List
        Parameter numbers of all pruning steps
    y1 : List
        mAPs after fine-tuning of all pruning steps
    y2 : List
        MACs of all pruning steps
    y3 : List
        mAPs after pruning (not fine-tuned) of all pruning steps

    Returns
    -------

    """
    try:
        plt.style.use("ggplot")
    except Exception:  # style may be unavailable; fall back to the default
        pass

    x, y1, y2, y3 = np.array(x), np.array(y1), np.array(y2), np.array(y3)
    y2_ratio = y2 / y2[0]

    # create the figure and the axis object
    fig, ax = plt.subplots(figsize=(8, 6))

    # plot the pruned mAP and recovered mAP
    ax.set_xlabel('Pruning Ratio')
    ax.set_ylabel('mAP')
    ax.plot(x, y1, label='recovered mAP')
    ax.scatter(x, y1)
    ax.plot(x, y3, color='tab:gray', label='pruned mAP')
    ax.scatter(x, y3, color='tab:gray')

    # create a second axis that shares the same x-axis
    ax2 = ax.twinx()

    # plot the second set of data
    ax2.set_ylabel('MACs')
    ax2.plot(x, y2_ratio, color='tab:orange', label='MACs')
    ax2.scatter(x, y2_ratio, color='tab:orange')

    # add a legend
    lines, labels = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax.set_xlim(105, -5)
    ax.set_ylim(0, max(y1) + 0.05)
    ax2.set_ylim(0.05, 1.05)

    # calculate the highest and lowest points for each set of data
    max_y1_idx = np.argmax(y1)
    min_y1_idx = np.argmin(y1)
    max_y2_idx = np.argmax(y2)
    min_y2_idx = np.argmin(y2)
    max_y1 = y1[max_y1_idx]
    min_y1 = y1[min_y1_idx]
    max_y2 = y2_ratio[max_y2_idx]
    min_y2 = y2_ratio[min_y2_idx]

    # add text for the highest and lowest values near the points
    ax.text(x[max_y1_idx], max_y1 - 0.05, f'max mAP = {max_y1:.2f}', fontsize=10)
    ax.text(x[min_y1_idx], min_y1 + 0.02, f'min mAP = {min_y1:.2f}', fontsize=10)
    ax2.text(x[max_y2_idx], max_y2 - 0.05, f'max MACs = {max_y2 * y2[0] / 1e9:.2f}G', fontsize=10)
    ax2.text(x[min_y2_idx], min_y2 + 0.02, f'min MACs = {min_y2 * y2[0] / 1e9:.2f}G', fontsize=10)

    plt.title('Comparison of mAP and MACs with Pruning Ratio')
    plt.savefig('pruning_perf_change.png')


def infer_shortcut(bottleneck):
    c1 = bottleneck.cv1.conv.in_channels
    c2 = bottleneck.cv2.conv.out_channels
    return c1 == c2 and hasattr(bottleneck, 'add') and bottleneck.add


class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


def transfer_weights(c2f, c2f_v2):
    c2f_v2.cv2 = c2f.cv2
    c2f_v2.m = c2f.m

    state_dict = c2f.state_dict()
    state_dict_v2 = c2f_v2.state_dict()

    # Transfer cv1 weights from C2f to cv0 and cv1 in C2f_v2
    old_weight = state_dict['cv1.conv.weight']
    half_channels = old_weight.shape[0] // 2
    state_dict_v2['cv0.conv.weight'] = old_weight[:half_channels]
    state_dict_v2['cv1.conv.weight'] = old_weight[half_channels:]

    # Transfer cv1 batchnorm weights and buffers from C2f to cv0 and cv1 in C2f_v2
    for bn_key in ['weight', 'bias', 'running_mean', 'running_var']:
        old_bn = state_dict[f'cv1.bn.{bn_key}']
        state_dict_v2[f'cv0.bn.{bn_key}'] = old_bn[:half_channels]
        state_dict_v2[f'cv1.bn.{bn_key}'] = old_bn[half_channels:]

    # Transfer remaining weights and buffers
    for key in state_dict:
        if not key.startswith('cv1.'):
            state_dict_v2[key] = state_dict[key]

    # Transfer all non-method attributes
    for attr_name in dir(c2f):
        attr_value = getattr(c2f, attr_name)
        if not callable(attr_value) and '_' not in attr_name:
            setattr(c2f_v2, attr_name, attr_value)

    c2f_v2.load_state_dict(state_dict_v2)


def replace_c2f_with_c2f_v2(module):
    for name, child_module in module.named_children():
        if isinstance(child_module, C2f):
            # Replace C2f with C2f_v2 while preserving its parameters
            shortcut = infer_shortcut(child_module.m[0])
            c2f_v2 = C2f_v2(child_module.cv1.conv.in_channels, child_module.cv2.conv.out_channels,
                            n=len(child_module.m), shortcut=shortcut,
                            g=child_module.m[0].cv2.conv.groups,
                            e=child_module.c / child_module.cv2.conv.out_channels)
            transfer_weights(child_module, c2f_v2)
            setattr(module, name, c2f_v2)
        else:
            replace_c2f_with_c2f_v2(child_module)


def save_model_v2(self: BaseTrainer):
    """Save model training checkpoints with additional metadata."""
    import io

    import pandas as pd  # scope for faster 'import ultralytics'

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    buffer = io.BytesIO()
    torch.save(
        {
            "epoch": self.epoch,
            "best_fitness": self.best_fitness,
            "model": None,  # resume and final checkpoints derive from EMA
            "ema": deepcopy(self.ema.ema),
            "updates": self.ema.updates,
            "optimizer": deepcopy(self.optimizer.state_dict()),
            "train_args": vars(self.args),  # save as dict
            "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
            "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()},
            "date": datetime.now().isoformat(),
            "version": __version__,
            "license": "AGPL-3.0 (https://ultralytics.com/license)",
            "docs": "https://docs.ultralytics.com",
        },
        buffer,
    )
    serialized_ckpt = buffer.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.last.write_bytes(serialized_ckpt)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(serialized_ckpt)  # save best.pt
    if (self.save_period > 0) and (self.epoch > 0) and (self.epoch % self.save_period == 0):
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt)  # save epoch, i.e. 'epoch3.pt'


def final_eval_v2(self: BaseTrainer):
    """Performs final evaluation and validation for object detection YOLO model."""
    for f in self.last, self.best:
        if f.exists():
            strip_optimizer_v2(f)  # strip optimizers
            if f is self.best:
                LOGGER.info(f"\nValidating {f}...")
                self.validator.args.plots = self.args.plots
                self.metrics = self.validator(model=f)
                self.metrics.pop("fitness", None)
                self.run_callbacks("on_fit_epoch_end")


def strip_optimizer_v2(f: Union[str, Path] = "best.pt", s: str = "") -> None:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str): file path to model to strip the optimizer from. Default is 'best.pt'.
        s (str): file path to save the model with stripped optimizer to. If not provided, 'f' will be overwritten.

    Returns:
        None

    Example:
        ```python
        from pathlib import Path
        from ultralytics.utils.torch_utils import strip_optimizer

        for f in Path('path/to/model/checkpoints').rglob('*.pt'):
            strip_optimizer(f)
        ```
    """
    try:
        x = torch.load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"WARNING ⚠️ Skipping {f}, not a valid Ultralytics model: {e}")
        return

    updates = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    # x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False

    # Update other keys
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save
    torch.save({**updates, **x}, s or f, use_dill=False)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")


def train_v2(self: YOLO, trainer=None, pruning=False, **kwargs):
    """
    Disabled loading new model when pruning flag is set. originated from ultralytics/yolo/engine/model.py
    """

    self._check_is_pytorch_model()
    if hasattr(self.session, "model") and self.session.model.id:  # Ultralytics HUB session with loaded model
        if any(kwargs):
            LOGGER.warning("WARNING ⚠️ using HUB training arguments, ignoring local training arguments.")
        kwargs = self.session.train_args  # overwrite kwargs

    checks.check_pip_update_available()

    overrides = yaml_load(checks.check_yaml(kwargs["cfg"])) if kwargs.get("cfg") else self.overrides
    custom = {
        # NOTE: handle the case when 'cfg' includes 'data'.
        "data": overrides.get("data") or DEFAULT_CFG_DICT["data"] or TASK2DATA[self.task],
        "model": self.overrides["model"],
        "task": self.task,
    }  # method defaults
    args = {**overrides, **custom, **kwargs, "mode": "train"}  # highest priority args on the right
    if args.get("resume"):
        args["resume"] = self.ckpt_path

    self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)

    if not pruning:
        if not overrides.get('resume'):  # manually set model only if not resuming
            self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            self.model = self.trainer.model

    else:
        # pruning mode
        self.trainer.pruning = True
        self.trainer.model = self.model

        # replace some functions to disable half precision saving
        self.trainer.__setattr__("save_model", save_model_v2.__get__(self.trainer))
        self.trainer.__setattr__("final_eval", final_eval_v2.__get__(self.trainer))
        # self.trainer.save_model = save_model_v2.__get__(self.trainer)
        # self.trainer.final_eval = final_eval_v2.__get__(self.trainer)

    self.trainer.hub_session = self.session  # attach optional HUB session
    self.trainer.train()
    # Update model and cfg after training
    if RANK in {-1, 0}:
        ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
        self.model, _ = attempt_load_one_weight(ckpt)
        self.overrides = self.model.args
        self.metrics = getattr(self.trainer.validator, "metrics", None)  # TODO: no metrics returned by DDP
    return self.metrics


def prune(args):
    # load trained yolov8 model
    model = YOLO(args.model)
    model.__setattr__("train", train_v2.__get__(model))
    pruning_cfg = yaml_load(checks.check_yaml(args.cfg))
    batch_size = pruning_cfg['batch']

    # fine-tune on the dataset in data.yaml for 1 epoch after each pruning step
    # sample values only; the dataset and number of epochs should come from the config file
    pruning_cfg['data'] = "data.yaml"
    pruning_cfg['epochs'] = 1

    model.model.train()
    replace_c2f_with_c2f_v2(model.model)
    initialize_weights(model.model)  # set BN.eps, momentum, ReLU.inplace

    for name, param in model.model.named_parameters():
        param.requires_grad = True

    example_inputs = torch.randn(1, 3, pruning_cfg["imgsz"], pruning_cfg["imgsz"]).to(model.device)
    macs_list, nparams_list, map_list, pruned_map_list = [], [], [], []
    base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)

    # do validation before pruning model
    pruning_cfg['name'] = "baseline_val"
    pruning_cfg['batch'] = 1
    validation_model = deepcopy(model)
    metric = validation_model.val(**pruning_cfg)
    init_map = metric.box.map
    macs_list.append(base_macs)
    nparams_list.append(100)
    map_list.append(init_map)
    pruned_map_list.append(init_map)
    print(f"Before Pruning: MACs={base_macs / 1e9: .5f} G, #Params={base_nparams / 1e6: .5f} M, mAP={init_map: .5f}")

    # per-step ratio chosen so that iterative_steps successive prunings compound to target_prune_rate
    pruning_ratio = 1 - math.pow((1 - args.target_prune_rate), 1 / args.iterative_steps)

    for i in range(args.iterative_steps):
        model.info()
        model.model.train()
        for name, param in model.model.named_parameters():
            param.requires_grad = True

        ignored_layers = []
        unwrapped_parameters = []
        for m in model.model.modules():
            if isinstance(m, (Detect,)):
                ignored_layers.append(m)

        example_inputs = example_inputs.to(model.device)
        pruner = tp.pruner.MagnitudePruner(
            model.model,
            example_inputs,
            importance=tp.importance.MagnitudeImportance(p=2),
            iterative_steps=1,
            pruning_ratio=pruning_ratio,  # per-step ratio computed above, not a fixed 50%
            ignored_layers=ignored_layers,
            unwrapped_parameters=unwrapped_parameters
        )

        # Test regularization
        # output = model.model(example_inputs)
        # (output[0].sum() + sum([o.sum() for o in output[1]])).backward()
        # pruner.regularize(model.model)

        pruner.step()

        # pre fine-tuning validation
        pruning_cfg['name'] = f"step_{i}_pre_val"
        pruning_cfg['batch'] = 1
        validation_model.model = deepcopy(model.model)
        metric = validation_model.val(**pruning_cfg)
        pruned_map = metric.box.map
        pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs.to(model.device))
        current_speed_up = float(macs_list[0]) / pruned_macs
        print(f"After pruning iter {i + 1}: MACs={pruned_macs / 1e9} G, #Params={pruned_nparams / 1e6} M, "
              f"mAP={pruned_map}, speed up={current_speed_up}")

        # fine-tuning
        for name, param in model.model.named_parameters():
            param.requires_grad = True
        pruning_cfg['name'] = f"step_{i}_finetune"
        pruning_cfg['batch'] = batch_size  # restore batch size
        model.train(**pruning_cfg)

        # post fine-tuning validation
        pruning_cfg['name'] = f"step_{i}_post_val"
        pruning_cfg['batch'] = 1
        validation_model = YOLO(model.trainer.best)
        metric = validation_model.val(**pruning_cfg)
        current_map = metric.box.map
        print(f"After fine tuning mAP={current_map}")

        macs_list.append(pruned_macs)
        nparams_list.append(pruned_nparams / base_nparams * 100)
        pruned_map_list.append(pruned_map)
        map_list.append(current_map)

        # remove pruner after single iteration
        del pruner

        save_pruning_performance_graph(nparams_list, map_list, macs_list, pruned_map_list)

        if init_map - current_map > args.max_map_drop:
            print("Pruning early stop")
            break

    model.export(format='ncnn', half=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='runs/drink_auto/weights/best.pt',
                        help='Pretrained pruning target model file')
    parser.add_argument('--cfg', default='default.yaml',
                        help='Pruning config file.'
                             ' This file should have same format with ultralytics/yolo/cfg/default.yaml')
    parser.add_argument('--iterative-steps', default=16, type=int, help='Total pruning iteration step')
    parser.add_argument('--target-prune-rate', default=0.5, type=float, help='Target pruning rate')
    parser.add_argument('--max-map-drop', default=1.0, type=float, help='Allowed maximum map drop after fine-tuning')

    args = parser.parse_args()

    prune(args)
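
As a worked example of the per-step ratio computed in prune() above: a new pruner is built on the already-pruned model in every iteration, so the per-step ratio compounds. The sketch below uses the script's defaults (--target-prune-rate 0.5, --iterative-steps 16) and only checks the nominal arithmetic. It assumes each step removes exactly that fraction of what remains, which the real pruner only approximates because of channel rounding and ignored layers.

```python
import math

target_prune_rate = 0.5  # default of --target-prune-rate
iterative_steps = 16     # default of --iterative-steps

# Same formula as in prune(): per-step ratio r such that (1 - r) ** steps == 1 - target.
per_step_ratio = 1 - math.pow(1 - target_prune_rate, 1 / iterative_steps)

# Apply the per-step ratio repeatedly to the nominal remaining fraction.
remaining = 1.0
for _ in range(iterative_steps):
    remaining *= 1 - per_step_ratio

print(f"per-step ratio: {per_step_ratio:.4f}")
print(f"nominal fraction remaining after {iterative_steps} steps: {remaining:.4f}")
```

With these defaults each step nominally removes a little over 4% of what is left, and the product over 16 steps lands on the 50% target.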
