def train_epoch(
epoch, model, loader, optimizer, loss_fn, cfg,
lr_scheduler=None, saver=None, output_dir='', use_amp=False,
model_ema=None, logger=None, writer=None, local_rank=0):
对选定分路径进行进一步训练并输出日志,返回残差
def validate(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0):
预测函数,输出日志并返回残差与精确度
def train_epoch(epoch, model, loader, optimizer, loss_fn, prioritized_board, MetaMN, cfg,
est=None, logger=None, lr_scheduler=None, saver=None,
output_dir='', model_ema=None, local_rank=0):
进行完整训练流程并输出日志,返回残差
def validate(model, loader, loss_fn, prioritized_board, cfg, log_suffix='', local_rank=0, logger=None):
进行完整检验流程并输出日志,返回残差与精确度
class InvertedResidual(nn.Module)
复制的timm源码,构造一个深度广度网络,结构为“点卷积-深度卷积-压缩卷积-线性层”
def feature_info(self, location):
功能未知,似乎是调用网络模块的信息,但好像并没有被调用过
def forward(self, x):
组网方法
def conv3x3(in_planes, out_planes, stride=1):
定义3*3卷积层函数
class BasicBlock(nn.Module):
定义一个结构为“3*3卷积-3*3卷积”的网络模块
def forward(self, x):
组网方法
class BasicBlock(nn.Module):
定义一个结构为“1*1卷积-3*3卷积-1*1卷积”的网络模块
def forward(self, x):
组网方法
def get_Bottleneck(in_c, out_c, stride):
Bottleneck构造函数
def get_BasicBlock(in_c, out_c, stride):
BasicBlock构造函数
class ChildNetBuilder:
分路径中可选部分的网络(包含选定)类
def _round_channels(self, chs):
保护方法,计算?
def _make_block(self, ba, block_idx, block_count):
保护方法,构造网络模块
def __call__(self, in_chs, model_block_args):
重写的调用方法
class SuperNetBuilder:
超网络中可选部分的网络类
def _round_channels(self, chs):
保护方法,计算?
def _make_block(self, ba, block_idx, block_count):
保护方法,构造网络模块
def __call__(self, in_chs, model_block_args):
重写的调用方法
class ChildNet(nn.Module):
分路径类
def get_classifier(self):
getter方法,返回网络末端的全连接分类器
def reset_classifier(self, num_classes, global_pool='avg'):
根据新的分类数重设网络末端的全连接分类器
def forward_features(self, x):
除全连接分类器外的组网
def forward(self, x):
全网络组网
def gen_childnet(arch_list, arch_def, **kwargs):
根据给定路径选项构造分路径
class SuperNet(nn.Module):
超网络类
def get_classifier(self):
getter方法,返回网络末端的全连接分类器
def reset_classifier(self, num_classes, global_pool='avg'):
根据新的分类数重设网络末端的全连接分类器
def forward_features(self, x):
除全连接分类器外的组网
def forward(self, x):
全网络组网
def forward_meta(self, features):
getter方法,返回匹配网络本体
def rand_parameters(self, architecture, meta=False):
getter方法,返回随机层参数,用于网络训练
class Classifier(nn.Module):
1000*1000全连接分类器,不知道干什么用的
def gen_supernet(flops_minimum=0, flops_maximum=600, **kwargs):
生成超网络函数
class MetaMatchingNetwork():
匹配网络类,可实现匹配网络相关功能,但不包含网络本体
def update_student_weights_only(self, random_cand, grad_1, optimizer, model):
与类无关的方法,仅对输入的分路径所包含的权重进行一次训练更新
def update_meta_weights_only(self, random_cand, teacher_cand, model, optimizer, grad_teacher):
仅更新匹配网络权重的函数
def simulate_sgd_update(self, w, g, optimizer):
模拟梯度下降步函数,返回变量张量
def calculate_1st_gradient(self, kd_loss, model, random_cand, optimizer):
梯度计算(针对训练集残差$L_{CE}$的梯度)
def calculate_2nd_gradient(self, validation_loss, model, optimizer, random_cand, teacher_cand, students_weight):
梯度计算(针对知识蒸馏,即先进路径的残差$L_{KD}$的梯度)
def forward_training(self, x, model, random_cand, teacher_cand, meta_value):
分路径与先进路径对比残差$L_{KD}$的计算函数
def forward_training(self, x, model, random_cand, teacher_cand, meta_value):
训练样本残差$L_{CE}$的计算函数
def isUpdate(self, current_epoch, batch_idx, prioritized_board):
检测是否本轮训练所有权重更新均已完成
def run_update(self, input, target, random_cand, model, optimizer,
单轮训练流程(默认包含匹配网络训练)
class PrioritizedBoard():
先进路径库
def select_teacher(self, model, random_cand):
为分路径匹配互补的先进路径
def board_size(self):
先进路径库大小
def get_prob(self):
获取路径广义softmax输出
def get_cand_with_prob(self, prob=None):
根据广义softmax输出获取路径
def isUpdate(self, current_epoch, prec1, flops):
检测是否本轮先进路径库所有更新操作都已完成
def update_prioritized_board(self, inputs, teacher_output, outputs, current_epoch, prec1, flops, cand):
更新先进路径库
def parse_ksize(ss):
数据类型转换函数,似乎是用于处理字符串形式网络架构标记,将其中的卷积核大小转换为数字的
def decode_arch_def(
arch_def,
depth_multiplier=1.0,
depth_trunc='ceil',
experts_multiplier=1):
将神经网络结构字符串转化为结构字典
def modify_block_args(block_args, kernel_size, exp_ratio):
对模块参数进行设置
def decode_block_str(block_str):
解析模块结构字符串(一个分路径包含多个模块)
def scale_stage_depth(
stack_args,
repeats,
depth_multiplier=1.0,
depth_trunc='ceil'):
根据深度缩放系数(depth_multiplier)重设各个模块重复次数
def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
模块权重初始化函数
def init_weight_goog(m, n='', fix_group_fanout=True, last_bn=None):
EfficientNet权重初始化函数
class FlopsEst(object):
网络模型复杂度评估类
def get_params(self, arch):
权重数据量获取方法
def get_flops(self, arch):
加/乘运算次数获取方法
def search_for_layer(flops_op_dict, arch_def, flops_minimum, flops_maximum):
按照给定模型复杂度范围搜索可用分路径
def get_path_acc(model, path, val_loader, args, val_iters=50):
计算分路径准确度
def get_logger(file_path):
日志生成函数
def add_weight_decay_supernet(model, args, weight_decay=1e-5, skip_list=()):
设置超网络权重衰减
def create_optimizer_supernet(args, model, has_apex, filter_bias_and_bn=True):
设置超网络模型优化器
def convert_lowercase(cfg):
大小写转换器,用于将yaml设置文件中的大写变量名转换为小写
def parse_config_args(exp_name):
设置文件解码器,等效于将设置文件内容导入Cream\lib\config.py后将所有设置一并导出
def get_model_flops_params(model, input_size=(1, 3, 224, 224)):
计算模型总参数量与总计算量
def cross_entropy_loss_with_soft_target(pred, soft_target):
logsoftmax = nn.LogSoftmax(dim=1)
return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))
特殊交叉熵函数(知识蒸馏用)
def create_supernet_scheduler(cfg, optimizer):
获取超网络训练时间安排
def add_path(path):
将现有路径添加到python目录
def main():
主运行函数,用于处理交互命令
def main():
再训练主函数,用于发起单个分路径的训练流程
def main():
检验主函数,用于发起利用验证集验证模型的流程
def main():
训练主函数,用于发起训练流程