Skip to content

Latest commit

 

History

History
466 lines (359 loc) · 10.2 KB

源码内容分析.md

File metadata and controls

466 lines (359 loc) · 10.2 KB
def train_epoch(
        epoch, model, loader, optimizer, loss_fn, cfg,
        lr_scheduler=None, saver=None, output_dir='', use_amp=False,
        model_ema=None, logger=None, writer=None, local_rank=0):

对选定分路径进行进一步训练并输出日志,返回残差

def validate(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0):

预测函数,输出日志并返回残差与精确度

def train_epoch(epoch, model, loader, optimizer, loss_fn, prioritized_board, MetaMN, cfg,
                est=None, logger=None, lr_scheduler=None, saver=None,
                output_dir='', model_ema=None, local_rank=0):

进行完整训练流程并输出日志,返回残差

def validate(model, loader, loss_fn, prioritized_board, cfg, log_suffix='', local_rank=0, logger=None):

进行完整检验流程并输出日志,返回残差与精确度

class InvertedResidual(nn.Module)

复制的timm源码,构造一个深度广度网络,结构为“点卷积-深度卷积-压缩卷积-线性层”

    def feature_info(self, location):

功能未知,似乎是调用网络模块的信息,但好像并没有被调用过

    def forward(self, x):

组网方法


def conv3x3(in_planes, out_planes, stride=1):

定义3*3卷积层函数

class BasicBlock(nn.Module):

定义一个结构为“3*3卷积-3*3卷积”的网络模块

    def forward(self, x):

组网方法

class BasicBlock(nn.Module):

定义一个结构为“1*1卷积-3*3卷积-1*1卷积”的网络模块

    def forward(self, x):

组网方法

def get_Bottleneck(in_c, out_c, stride):

Bottleneck构造函数

def get_BasicBlock(in_c, out_c, stride):

BasicBlock构造函数

class ChildNetBuilder:

分路径中可选部分的网络(包含选定)类

    def _round_channels(self, chs):

保护方法,计算?

    def _make_block(self, ba, block_idx, block_count):

保护方法,构造网络模块

    def __call__(self, in_chs, model_block_args):

重写的调用方法

class SuperNetBuilder:

超网络中可选部分的网络类

    def _round_channels(self, chs):

保护方法,计算?

    def _make_block(self, ba, block_idx, block_count):

保护方法,构造网络模块

    def __call__(self, in_chs, model_block_args):

重写的调用方法

class ChildNet(nn.Module):

分路径类

    def get_classifier(self):

getter方法,返回网络末端的全连接分类器

    def reset_classifier(self, num_classes, global_pool='avg'):

根据新的分类数重设网络末端的全连接分类器

    def forward_features(self, x):

除全连接分类器外的组网

    def forward(self, x):

全网络组网

def gen_childnet(arch_list, arch_def, **kwargs):

根据给定路径选项构造分路径

class SuperNet(nn.Module):

超网络类

    def get_classifier(self):

getter方法,返回网络末端的全连接分类器

    def reset_classifier(self, num_classes, global_pool='avg'):

根据新的分类数重设网络末端的全连接分类器

    def forward_features(self, x):

除全连接分类器外的组网

    def forward(self, x):

全网络组网

    def forward_meta(self, features):

getter方法,返回匹配网络本体

    def rand_parameters(self, architecture, meta=False):

getter方法,返回随机层参数,用于网络训练

class Classifier(nn.Module):

1000*1000全连接分类器,不知道干什么用的

def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs):

生成超网络函数

class MetaMatchingNetwork():

匹配网络类,可实现匹配网络相关功能,但不包含网络本体

    def update_student_weights_only(self, random_cand, grad_1, optimizer, model):

与类无关的方法,仅对输入的分路径所包含的权重进行一次训练更新

    def update_meta_weights_only(self, random_cand, teacher_cand, model, optimizer, grad_teacher):

仅更新匹配网络权重的函数

    def simulate_sgd_update(self, w, g, optimizer):

模拟梯度下降步函数,返回变量张量

    def calculate_1st_gradient(self, kd_loss, model, random_cand, optimizer):

梯度计算(针对训练集残差$L_{CE}$的梯度)

    def calculate_2nd_gradient(self, validation_loss, model, optimizer, random_cand, teacher_cand, students_weight):

梯度计算(针对知识蒸馏,即先进路径的残差$L_{KD}$的梯度)

    def forward_training(self, x, model, random_cand, teacher_cand, meta_value):

分路径与先进路径对比残差$L_{KD}$的计算函数

    def forward_training(self, x, model, random_cand, teacher_cand, meta_value):

训练样本残差$L_{CE}$的计算函数

    def isUpdate(self, current_epoch, batch_idx, prioritized_board):

检测是否本轮训练所有权重更新均已完成

    def run_update(self, input, target, random_cand, model, optimizer,

单轮训练流程(默认包含匹配网络训练)

class PrioritizedBoard():

先进路径库

    def select_teacher(self, model, random_cand):

为分路径匹配互补的先进路径

    def board_size(self):

先进路径库大小

    def get_prob(self):

获取路径广义softmax输出

    def get_cand_with_prob(self, prob=None):

根据广义softmax输出获取路径

    def isUpdate(self, current_epoch, prec1, flops):

检测是否本轮先进路径库所有更新操作都已完成

    def update_prioritized_board(self, inputs, teacher_output, outputs, current_epoch, prec1, flops, cand):

更新先进路径库

def parse_ksize(ss):

数据类型转换函数,似乎是用于处理字符串形式网络架构标记,将其中的卷积核大小转换为数字的

def decode_arch_def(
        arch_def,
        depth_multiplier=1.0,
        depth_trunc='ceil',
        experts_multiplier=1):

将神经网络结构字符串转化为结构字典

def modify_block_args(block_args, kernel_size, exp_ratio):

对模块参数进行设置

def decode_block_str(block_str):

解析模块结构字符串(一个分路径包含多个模块)

def scale_stage_depth(
        stack_args,
        repeats,
        depth_multiplier=1.0,
        depth_trunc='ceil'):

根据深度缩放系数(depth_multiplier)重设各个模块重复次数

def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):

模块权重初始化函数

def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):

EfficientNet权重初始化函数

class FlopsEst(object):

网络模型复杂度评估类

    def get_params(self, arch):

权重数据量获取方法

    def get_flops(self, arch):

加/乘运算次数获取方法

def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum):

按照给定模型复杂度范围搜索可用分路径

def get_path_acc(model, path, val_loader, args, val_iters=50):

计算分路径准确度

def get_logger(file_path):

日志生成函数

def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):

设置超网络权重衰减

def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):

设置超网络模型优化器

def convert_lowercase(cfg):

大小写转换器,用于将yaml设置文件中的大写变量名转换为小写

def parse_config_args(exp_name):

设置文件解码器,等效于将设置文件内容导入Cream\lib\config.py后将所有设置一并导出

def get_model_flops_params(model, input_size=(1, 3, 224, 224)):

计算模型总参数量与总计算量

def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

特殊交叉熵函数(知识蒸馏用)

def create_supernet_scheduler(cfg, optimizer):

获取超网络训练时间安排

def add_path(path):

将现有路径添加到python目录

def main():

主运行函数,用于处理交互命令

def main():

再训练主函数,用于发起单个分路径的训练流程

def main():

检验主函数,用于发起利用验证集验证模型的流程

def main():

训练主函数,用于发起训练流程