diff --git a/configs/_base_/models/asformer.py b/configs/_base_/models/asformer.py
new file mode 100644
index 0000000000..6468449ada
--- /dev/null
+++ b/configs/_base_/models/asformer.py
@@ -0,0 +1,12 @@
+# model settings
+model = dict(
+    type='ASFormer',
+    num_layers=10,
+    num_f_maps=64,
+    input_dim=2048,
+    num_decoders=3,
+    num_classes=11,
+    channel_masking_rate=0.5,
+    sample_rate=1,
+    r1=2,
+    r2=2)
diff --git a/configs/segmentation/asformer/asformer_gtea.py b/configs/segmentation/asformer/asformer_gtea.py
new file mode 100644
index 0000000000..ec9f5d595d
--- /dev/null
+++ b/configs/segmentation/asformer/asformer_gtea.py
@@ -0,0 +1,106 @@
+_base_ = ['../../_base_/models/asformer.py', '../../_base_/default_runtime.py']
+
+dataset_type = 'ActionSegmentDataset'
+data_root = 'data/action_seg/gtea/'
+data_root_val = 'data/action_seg/gtea/'
+ann_file_train = 'data/action_seg/gtea/splits/train.split1.bundle'
+ann_file_val = 'data/action_seg/gtea/splits/test.split1.bundle'
+ann_file_test = 'data/action_seg/gtea/splits/test.split1.bundle'
+
+train_pipeline = [
+    dict(type='LoadSegmentationFeature'),
+    dict(
+        type='PackSegmentationInputs',
+        keys=('classes', ),
+        meta_keys=('num_classes', 'actions_dict', 'index2label',
+                   'ground_truth', 'classes'))
+]
+
+val_pipeline = [
+    dict(type='LoadSegmentationFeature'),
+    dict(
+        type='PackSegmentationInputs',
+        keys=('classes', ),
+        meta_keys=('num_classes', 'actions_dict', 'index2label',
+                   'ground_truth', 'classes'))
+]
+
+test_pipeline = [
+    dict(type='LoadSegmentationFeature'),
+    dict(
+        type='PackSegmentationInputs',
+        keys=('classes', ),
+        meta_keys=('num_classes', 'actions_dict', 'index2label',
+                   'ground_truth', 'classes'))
+]
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    drop_last=True,
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_train,
+        data_prefix=dict(video=data_root),
+        pipeline=train_pipeline))
+
+val_dataloader = dict(
+    batch_size=1,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_val,
+        data_prefix=dict(video=data_root_val),
+        pipeline=val_pipeline,
+        test_mode=True))
+
+test_dataloader = dict(
+    batch_size=1,
+    num_workers=8,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        ann_file=ann_file_test,
+        data_prefix=dict(video=data_root_val),
+        pipeline=test_pipeline,
+        test_mode=True))
+
+max_epochs = 9
+train_cfg = dict(
+    type='EpochBasedTrainLoop',
+    max_epochs=max_epochs,
+    val_begin=1,
+    val_interval=1)
+
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+optim_wrapper = dict(
+    optimizer=dict(type='Adam', lr=0.001, weight_decay=0.0001),
+    clip_grad=dict(max_norm=40, norm_type=2))
+
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[7],
+        gamma=0.1)
+]
+
+work_dir = './work_dirs/asformer_gtea/'
+load_from = None
+
+test_evaluator = dict(
+    type='SegmentMetric',
+    metric_type='ALL',
+    dump_config=dict(out=f'{work_dir}/results.json', output_format='json'))
+val_evaluator = test_evaluator
diff --git a/mmaction/datasets/__init__.py b/mmaction/datasets/__init__.py
index ded946b727..bb45165929 100644
--- a/mmaction/datasets/__init__.py
+++ b/mmaction/datasets/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .action_segment_dataset import ActionSegmentDataset
 from .activitynet_dataset import ActivityNetDataset
 from .audio_dataset import AudioDataset
 from .ava_dataset import AVADataset, AVAKineticsDataset
@@ -13,5 +14,6 @@
 __all__ = [
     'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset',
     'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset',
-    'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset'
+    'VideoDataset', 'repeat_pseudo_collate', 'VideoTextDataset',
+    'ActionSegmentDataset'
 ]
diff --git a/mmaction/datasets/action_segment_dataset.py b/mmaction/datasets/action_segment_dataset.py
new file mode 100644
index 0000000000..ee5a76573f
--- /dev/null
+++ b/mmaction/datasets/action_segment_dataset.py
@@ -0,0 +1,63 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Callable, List, Optional, Union
+
+from mmengine.fileio import exists
+
+from mmaction.registry import DATASETS
+from mmaction.utils import ConfigType
+from .base import BaseActionDataset
+
+
+@DATASETS.register_module()
+class ActionSegmentDataset(BaseActionDataset):
+    """Action segmentation dataset (e.g. GTEA, 50Salads, Breakfast).
+
+    The dataset reads the video names listed in a split ``.bundle`` file and
+    pairs each of them with a per-frame feature file under ``features/``, a
+    frame-wise ground-truth label file under ``groundTruth/`` and the class
+    mapping defined in ``mapping.txt``.
+
+    Args:
+        ann_file (str): Path to the split file, e.g.
+            ``splits/train.split1.bundle``.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+        data_prefix (dict): Path to the directory that contains
+            ``features/``, ``groundTruth/`` and ``mapping.txt``.
+            Defaults to ``dict(video='')``.
+        test_mode (bool): Store True when building the test or validation
+            dataset. Defaults to False.
+    """
+
+    def __init__(self,
+                 ann_file: str,
+                 pipeline: List[Union[dict, Callable]],
+                 data_prefix: Optional[ConfigType] = dict(video=''),
+                 test_mode: bool = False,
+                 **kwargs):
+        super().__init__(
+            ann_file,
+            pipeline=pipeline,
+            data_prefix=data_prefix,
+            test_mode=test_mode,
+            **kwargs)
+
+    def load_data_list(self) -> List[dict]:
+        """Load annotation file to get video information."""
+        exists(self.ann_file)
+        file_ptr = open(self.ann_file, 'r')
+        list_of_examples = file_ptr.read().split('\n')[:-1]
+        file_ptr.close()
+        gts = [
+            self.data_prefix['video'] + 'groundTruth/' + vid
+            for vid in list_of_examples
+        ]
+        features_npy = [
+            self.data_prefix['video'] + 'features/' + vid.split('.')[0] +
+            '.npy' for vid in list_of_examples
+        ]
+        data_list = []
+
+        file_ptr_d = open(self.data_prefix['video'] + '/mapping.txt', 'r')
+        actions = file_ptr_d.read().split('\n')[:-1]
+        file_ptr_d.close()
+        actions_dict = dict()
+        for a in actions:
+            actions_dict[a.split()[1]] = int(a.split()[0])
+        index2label = dict()
+        for k, v in actions_dict.items():
+            index2label[v] = k
+        num_classes = len(actions_dict)
+        for idx, feature in enumerate(features_npy):
+            video_info = dict()
+            video_info['feature_path'] = feature
+            video_info['actions_dict'] = actions_dict
+            video_info['index2label'] = index2label
+            video_info['ground_truth_path'] = gts[idx]
+            video_info['num_classes'] = num_classes
+            data_list.append(video_info)
+        return data_list
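
Note: `ActionSegmentDataset` assumes the MS-TCN/ASFormer-style data layout, i.e. a `features/` directory of `.npy` files, a `groundTruth/` directory of frame-wise label files, a `mapping.txt` and split `.bundle` lists. The sketch below (not part of the patch, with made-up file contents) mirrors what `load_data_list` does with those files.

# Illustrative only: how mapping.txt and a split .bundle are parsed.
mapping_txt = '0 background\n1 take\n2 open\n'
bundle_txt = 'S1_Cheese_C1.txt\nS1_Coffee_C1.txt\n'

actions = mapping_txt.split('\n')[:-1]
actions_dict = {line.split()[1]: int(line.split()[0]) for line in actions}
index2label = {idx: label for label, idx in actions_dict.items()}
videos = bundle_txt.split('\n')[:-1]

# Each listed video maps to features/<name>.npy and groundTruth/<name>.
feature_paths = ['features/' + vid.split('.')[0] + '.npy' for vid in videos]
gt_paths = ['groundTruth/' + vid for vid in videos]
print(actions_dict, index2label, feature_paths, gt_paths)
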
diff --git a/mmaction/datasets/transforms/__init__.py b/mmaction/datasets/transforms/__init__.py
index f2670cd929..e71c70bdb7 100644
--- a/mmaction/datasets/transforms/__init__.py
+++ b/mmaction/datasets/transforms/__init__.py
@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .formatting import (FormatAudioShape, FormatGCNInput, FormatShape,
-                         PackActionInputs, PackLocalizationInputs, Transpose)
+                         PackActionInputs, PackLocalizationInputs,
+                         PackSegmentationInputs, Transpose)
 from .loading import (ArrayDecode, AudioFeatureSelector, BuildPseudoClip,
                       DecordDecode, DecordInit, DenseSampleFrames,
-                      GenerateLocalizationLabels, ImageDecode,
-                      LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature,
-                      LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit,
+                      GenerateLocalizationLabels, GenerateSegmentationLabels,
+                      ImageDecode, LoadAudioFeature, LoadHVULabel,
+                      LoadLocalizationFeature, LoadProposals, LoadRGBFromFile,
+                      LoadSegmentationFeature, OpenCVDecode, OpenCVInit,
                       PIMSDecode, PIMSInit, PyAVDecode, PyAVDecodeMotionVector,
                       PyAVInit, RawFrameDecode, SampleAVAFrames, SampleFrames,
                       UniformSample, UntrimmedSampleFrames)
@@ -37,5 +39,6 @@
     'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop', 'ToMotion',
     'TorchVisionWrapper', 'Transpose', 'UniformSample', 'UniformSampleFrames',
     'UntrimmedSampleFrames', 'MMUniformSampleFrames', 'MMDecode', 'MMCompact',
-    'CLIPTokenize'
+    'CLIPTokenize', 'LoadSegmentationFeature', 'GenerateSegmentationLabels',
+    'PackSegmentationInputs'
 ]
diff --git a/mmaction/datasets/transforms/formatting.py b/mmaction/datasets/transforms/formatting.py
index 9b9cb375a9..b6371059c4 100644
--- a/mmaction/datasets/transforms/formatting.py
+++ b/mmaction/datasets/transforms/formatting.py
@@ -170,6 +170,65 @@ def __repr__(self) -> str:
         return repr_str
 
 
+@TRANSFORMS.register_module()
+class PackSegmentationInputs(BaseTransform):
+    """Pack the input data for action segmentation.
+
+    Args:
+        keys (Sequence[str]): The keys to be packed into
+            ``data_sample.gt_instances``. Defaults to ``()``.
+        meta_keys (Sequence[str]): The meta keys to be saved in the
+            ``metainfo`` of the ``data_sample``.
+            Defaults to ``('video_name', )``.
+    """
+
+    def __init__(self, keys=(), meta_keys=('video_name', )):
+        self.keys = keys
+        self.meta_keys = meta_keys
+
+    def transform(self, results):
+        """Method to pack the input data.
+
+        Args:
+            results (dict): Result dict from the data pipeline.
+
+        Returns:
+            dict:
+
+            - 'inputs' (obj:`torch.Tensor`): The forward data of models.
+            - 'data_samples' (obj:`ActionDataSample`): The annotation info
+              of the sample.
+        """
+        packed_results = dict()
+        if 'raw_feature' in results:
+            raw_feature = results['raw_feature']
+            packed_results['inputs'] = to_tensor(raw_feature)
+        elif 'bsp_feature' in results:
+            packed_results['inputs'] = torch.tensor(0.)
+        else:
+            raise ValueError(
+                'Cannot get "raw_feature" or "bsp_feature" in the input '
+                'dict of `PackSegmentationInputs`.')
+
+        data_sample = ActionDataSample()
+        for key in self.keys:
+            if key not in results:
+                continue
+            if key == 'classes':
+                instance_data = InstanceData()
+                instance_data[key] = to_tensor(results[key])
+                data_sample.gt_instances = instance_data
+            elif key == 'proposals':
+                instance_data = InstanceData()
+                instance_data[key] = to_tensor(results[key])
+                data_sample.proposals = instance_data
+            else:
+                raise NotImplementedError(
+                    f"Key '{key}' is not supported in "
+                    '`PackSegmentationInputs`')
+
+        img_meta = {k: results[k] for k in self.meta_keys if k in results}
+        data_sample.set_metainfo(img_meta)
+        packed_results['data_samples'] = data_sample
+        return packed_results
+
+    def __repr__(self) -> str:
+        repr_str = self.__class__.__name__
+        repr_str += f'(meta_keys={self.meta_keys})'
+        return repr_str
+
+
 @TRANSFORMS.register_module()
 class Transpose(BaseTransform):
     """Transpose image channels to a given order.
diff --git a/mmaction/datasets/transforms/loading.py b/mmaction/datasets/transforms/loading.py
index 22070371a1..7a9271f24d 100644
--- a/mmaction/datasets/transforms/loading.py
+++ b/mmaction/datasets/transforms/loading.py
@@ -1854,6 +1854,77 @@ def transform(self, results):
         return results
 
 
+@TRANSFORMS.register_module()
+class LoadSegmentationFeature(BaseTransform):
+    """Load frame-wise video features for action segmentation.
+
+    Required keys are "feature_path", "ground_truth_path" and
+    "actions_dict", added or modified keys are "raw_feature",
+    "ground_truth" and "classes".
+    """
+
+    def transform(self, results):
+        """Perform the LoadSegmentationFeature loading.
+
+        Args:
+            results (dict): The resulting dict to be modified and passed
+                to the next transform in pipeline.
+        """
+        raw_feature = np.load(results['feature_path'])
+        file_ptr = open(results['ground_truth_path'], 'r')
+        content = file_ptr.read().split('\n')[:-1]
+        file_ptr.close()
+        classes = np.zeros(min(np.shape(raw_feature)[1], len(content)))
+        for i in range(len(classes)):
+            classes[i] = results['actions_dict'][content[i]]
+
+        results['raw_feature'] = raw_feature
+        results['ground_truth'] = content
+        results['classes'] = classes
+
+        return results
+
+    def __repr__(self):
+        repr_str = f'{self.__class__.__name__}'
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class GenerateSegmentationLabels(BaseTransform):
+    """Generate duration-normalized segment labels for segmentation.
+
+    Required keys are "duration_frame", "duration_second", "feature_frame"
+    and "annotations", added or modified keys are "gt_bbox".
+    """
+
+    def transform(self, results):
+        """Perform the GenerateSegmentationLabels loading.
+
+        Args:
+            results (dict): The resulting dict to be modified and passed
+                to the next transform in pipeline.
+        """
+        video_frame = results['duration_frame']
+        video_second = results['duration_second']
+        feature_frame = results['feature_frame']
+        corrected_second = float(feature_frame) / video_frame * video_second
+        annotations = results['annotations']
+
+        gt_bbox = []
+
+        for annotation in annotations:
+            current_start = max(
+                min(1, annotation['segment'][0] / corrected_second), 0)
+            current_end = max(
+                min(1, annotation['segment'][1] / corrected_second), 0)
+            gt_bbox.append([current_start, current_end])
+
+        gt_bbox = np.array(gt_bbox)
+        results['gt_bbox'] = gt_bbox
+        return results
+
+
 @TRANSFORMS.register_module()
 class LoadProposals(BaseTransform):
     """Loading proposals with given proposal results.
diff --git a/mmaction/evaluation/metrics/__init__.py b/mmaction/evaluation/metrics/__init__.py
index 8bf22c6672..d6f47ffcce 100644
--- a/mmaction/evaluation/metrics/__init__.py
+++ b/mmaction/evaluation/metrics/__init__.py
@@ -4,8 +4,9 @@
 from .ava_metric import AVAMetric
 from .multisports_metric import MultiSportsMetric
 from .retrieval_metric import RetrievalMetric
+from .segment_metric import SegmentMetric
 
 __all__ = [
     'AccMetric', 'AVAMetric', 'ANetMetric', 'ConfusionMatrix',
-    'MultiSportsMetric', 'RetrievalMetric'
+    'MultiSportsMetric', 'RetrievalMetric', 'SegmentMetric'
 ]
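
Note: a quick way to sanity-check the two new transforms is to run them on a dummy sample. The snippet below is illustrative only (the temporary files, label names and shapes are made up) and assumes this patch is installed; it is not part of the patch.

# Illustrative only: exercising LoadSegmentationFeature + PackSegmentationInputs.
import os
import tempfile

import numpy as np

from mmaction.datasets.transforms import (LoadSegmentationFeature,
                                          PackSegmentationInputs)

tmp_dir = tempfile.mkdtemp()
feat_path = os.path.join(tmp_dir, 'demo.npy')
gt_path = os.path.join(tmp_dir, 'demo.txt')
np.save(feat_path, np.random.rand(2048, 5).astype(np.float32))  # (C, L)
with open(gt_path, 'w') as f:
    f.write('take\ntake\nopen\nopen\nopen\n')

results = dict(
    feature_path=feat_path,
    ground_truth_path=gt_path,
    actions_dict=dict(take=0, open=1),
    num_classes=2,
    index2label={0: 'take', 1: 'open'})
results = LoadSegmentationFeature()(results)
packed = PackSegmentationInputs(
    keys=('classes', ),
    meta_keys=('num_classes', 'actions_dict', 'index2label', 'ground_truth',
               'classes'))(results)
print(packed['inputs'].shape)          # torch.Size([2048, 5])
print(packed['data_samples'].classes)  # frame-wise class indices
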
diff --git a/mmaction/evaluation/metrics/segment_metric.py b/mmaction/evaluation/metrics/segment_metric.py
new file mode 100644
index 0000000000..97b4f05f52
--- /dev/null
+++ b/mmaction/evaluation/metrics/segment_metric.py
@@ -0,0 +1,264 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from collections import OrderedDict
+from typing import Any, Optional, Sequence, Tuple
+
+import mmcv
+import mmengine
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmaction.registry import METRICS
+from mmaction.utils import ConfigType
+
+
+@METRICS.register_module()
+class SegmentMetric(BaseMetric):
+    """Action segmentation evaluation metric."""
+
+    def __init__(self,
+                 metric_type: str = 'TEM',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 metric_options: dict = {},
+                 dump_config: ConfigType = dict(out='')):
+        super().__init__(collect_device=collect_device, prefix=prefix)
+        self.metric_type = metric_type
+        assert 'out' in dump_config
+        self.output_format = dump_config.pop('output_format', 'csv')
+        self.out = dump_config['out']
+        self.metric_options = metric_options
+
+    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+                predictions: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[dict]): A batch of outputs from
+                the model.
+        """
+        for pred in predictions:
+            self.results.append(pred)
+
+    def compute_metrics(self, results: list) -> dict:
+        """Compute the metrics from processed results.
+
+        Args:
+            results (list): The processed results of each batch.
+
+        Returns:
+            dict: The computed metrics. The keys are the names of the
+                metrics, and the values are corresponding results.
+        """
+        if self.metric_type == 'ALL':
+            return self.compute_ALL(results)
+        return OrderedDict()
+
+    def compute_ALL(self, results: list) -> dict:
+        """Compute frame-wise accuracy, edit score and F1@{10, 25, 50}."""
+        eval_results = OrderedDict()
+        overlap = [.1, .25, .5]
+        tp, fp, fn = np.zeros(3), np.zeros(3), np.zeros(3)
+
+        correct = 0
+        total = 0
+        edit = 0
+
+        for vid in results:
+            gt_content = vid['ground']
+            recog_content = vid['recognition']
+
+            for i in range(len(gt_content)):
+                total += 1
+                if gt_content[i] == recog_content[i]:
+                    correct += 1
+
+            edit += self.edit_score(recog_content, gt_content)
+
+            for s in range(len(overlap)):
+                tp1, fp1, fn1 = self.f_score(recog_content, gt_content,
+                                             overlap[s])
+                tp[s] += tp1
+                fp[s] += fp1
+                fn[s] += fn1
+        eval_results['Acc'] = 100 * float(correct) / total
+        eval_results['Edit'] = (1.0 * edit) / len(results)
+        f1s = np.array([0, 0, 0], dtype=float)
+        for s in range(len(overlap)):
+            precision = tp[s] / float(tp[s] + fp[s])
+            recall = tp[s] / float(tp[s] + fn[s])
+
+            f1 = 2.0 * (precision * recall) / (precision + recall)
+
+            f1 = np.nan_to_num(f1) * 100
+            f1s[s] = f1
+
+        eval_results['F1@10'] = f1s[0]
+        eval_results['F1@25'] = f1s[1]
+        eval_results['F1@50'] = f1s[2]
+
+        return eval_results
+
+    def dump_results(self, results, version='VERSION 1.3'):
+        """Save middle or final results to disk."""
+        if self.output_format == 'json':
+            result_dict = self.proposals2json(results)
+            output_dict = {
+                'version': version,
+                'results': result_dict,
+                'external_data': {}
+            }
+            mmengine.dump(output_dict, self.out)
+        elif self.output_format == 'csv':
+            os.makedirs(self.out, exist_ok=True)
+            header = 'action,start,end,tmin,tmax'
+            for result in results:
+                video_name, outputs = result
+                output_path = osp.join(self.out, video_name + '.csv')
+                np.savetxt(
+                    output_path,
+                    outputs,
+                    header=header,
+                    delimiter=',',
+                    comments='')
+        else:
+            raise ValueError(
+                f'The output format {self.output_format} is not supported.')
+
+    @staticmethod
+    def proposals2json(results, show_progress=False):
+        """Convert all proposals to a final dict(json) format.
+
+        Args:
+            results (list[dict]): All proposals.
+            show_progress (bool): Whether to show the progress bar.
+                Defaults: False.
+
+        Returns:
+            dict: The final result dict. E.g.
+
+            .. code-block:: Python
+
+                dict(video-1=[dict(segment=[1.1, 2.0], score=0.9),
+                              dict(segment=[50.1, 129.3], score=0.6)])
+        """
+        result_dict = {}
+        print('Convert proposals to json format')
+        if show_progress:
+            prog_bar = mmcv.ProgressBar(len(results))
+        for result in results:
+            video_name = result['video_name']
+            result_dict[video_name[2:]] = result['proposal_list']
+            if show_progress:
+                prog_bar.update()
+        return result_dict
+
+    @staticmethod
+    def _import_proposals(results):
+        """Read predictions from results."""
+        proposals = {}
+        num_proposals = 0
+        for result in results:
+            video_id = result['video_name'][2:]
+            this_video_proposals = []
+            for proposal in result['proposal_list']:
+                t_start, t_end = proposal['segment']
+                score = proposal['score']
+                this_video_proposals.append([t_start, t_end, score])
+                num_proposals += 1
+            proposals[video_id] = np.array(this_video_proposals)
+        return proposals, num_proposals
+
+    def f_score(self,
+                recognized,
+                ground_truth,
+                overlap,
+                bg_class=['background']):
+        """Count segment-level TP/FP/FN at the given IoU overlap."""
+        p_label, p_start, p_end = self.get_labels_start_end_time(
+            recognized, bg_class)
+        y_label, y_start, y_end = self.get_labels_start_end_time(
+            ground_truth, bg_class)
+
+        tp = 0
+        fp = 0
+
+        hits = np.zeros(len(y_label))
+
+        for j in range(len(p_label)):
+            intersection = np.minimum(p_end[j], y_end) - np.maximum(
+                p_start[j], y_start)
+            union = np.maximum(p_end[j], y_end) - np.minimum(
+                p_start[j], y_start)
+            IoU = (1.0 * intersection / union) * (
+                [p_label[j] == y_label[x] for x in range(len(y_label))])
+            idx = np.array(IoU).argmax()
+
+            if IoU[idx] >= overlap and not hits[idx]:
+                tp += 1
+                hits[idx] = 1
+            else:
+                fp += 1
+        fn = len(y_label) - sum(hits)
+        return float(tp), float(fp), float(fn)
+
+    def edit_score(self,
+                   recognized,
+                   ground_truth,
+                   norm=True,
+                   bg_class=['background']):
+        """Edit (Levenshtein) score between predicted and GT segments."""
+        P, _, _ = self.get_labels_start_end_time(recognized, bg_class)
+        Y, _, _ = self.get_labels_start_end_time(ground_truth, bg_class)
+        return self.levenstein(P, Y, norm)
+
+    def get_labels_start_end_time(self,
+                                  frame_wise_labels,
+                                  bg_class=['background']):
+        """Convert frame-wise labels into (labels, starts, ends) segments."""
+        labels = []
+        starts = []
+        ends = []
+        last_label = frame_wise_labels[0]
+        if frame_wise_labels[0] not in bg_class:
+            labels.append(frame_wise_labels[0])
+            starts.append(0)
+        for i in range(len(frame_wise_labels)):
+            if frame_wise_labels[i] != last_label:
+                if frame_wise_labels[i] not in bg_class:
+                    labels.append(frame_wise_labels[i])
+                    starts.append(i)
+                if last_label not in bg_class:
+                    ends.append(i)
+                last_label = frame_wise_labels[i]
+        if last_label not in bg_class:
+            ends.append(i)
+        return labels, starts, ends
+
+    def levenstein(self, p, y, norm=False):
+        """Levenshtein distance between two segment-label sequences."""
+        m_row = len(p)
+        n_col = len(y)
+        D = np.zeros([m_row + 1, n_col + 1], np.float64)
+        for i in range(m_row + 1):
+            D[i, 0] = i
+        for i in range(n_col + 1):
+            D[0, i] = i
+
+        for j in range(1, n_col + 1):
+            for i in range(1, m_row + 1):
+                if y[j - 1] == p[i - 1]:
+                    D[i, j] = D[i - 1, j - 1]
+                else:
+                    D[i, j] = min(D[i - 1, j] + 1, D[i, j - 1] + 1,
+                                  D[i - 1, j - 1] + 1)
+
+        if norm:
+            score = (1 - D[-1, -1] / max(m_row, n_col)) * 100
+        else:
+            score = D[-1, -1]
+
+        return score
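
Note: `SegmentMetric` consumes the `ground`/`recognition` frame-label lists emitted by `ASFormer.predict`. A toy run with made-up labels (not part of the patch; `'tmp.json'` is only a placeholder for the required `out` entry) looks like this:

# Illustrative only: what SegmentMetric measures on a toy prediction.
from mmaction.evaluation.metrics import SegmentMetric

metric = SegmentMetric(metric_type='ALL', dump_config=dict(out='tmp.json'))
pred = dict(
    ground=['take'] * 4 + ['open'] * 4,
    recognition=['take'] * 3 + ['open'] * 5)
metric.process(data_batch=None, predictions=[pred])
print(metric.compute_metrics(metric.results))
# e.g. {'Acc': 87.5, 'Edit': 100.0, 'F1@10': 100.0, ...}
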
diff --git a/mmaction/models/action_segmentors/__init__.py b/mmaction/models/action_segmentors/__init__.py
new file mode 100644
index 0000000000..eeb81b53e9
--- /dev/null
+++ b/mmaction/models/action_segmentors/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .asformer import ASFormer
+
+__all__ = [
+    'ASFormer',
+]
diff --git a/mmaction/models/action_segmentors/asformer.py b/mmaction/models/action_segmentors/asformer.py
new file mode 100644
index 0000000000..aefb62a96d
--- /dev/null
+++ b/mmaction/models/action_segmentors/asformer.py
@@ -0,0 +1,610 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmengine.model import BaseModel
+
+from mmaction.registry import MODELS
+
+
+@MODELS.register_module()
+class ASFormer(BaseModel):
+    """ASFormer: Transformer for Action Segmentation.
+
+    Args:
+        num_decoders (int): Number of decoder stages.
+        num_layers (int): Number of attention layers per stage.
+        r1 (int): Reduction ratio of the query/key projections.
+        r2 (int): Reduction ratio of the value projection.
+        num_f_maps (int): Number of feature channels inside each stage.
+        input_dim (int): Channel number of the input frame-wise features.
+        num_classes (int): Number of action classes.
+        channel_masking_rate (float): Dropout rate applied on the input
+            feature channels.
+        sample_rate (int): Temporal downsampling rate of the features.
+    """
+
+    def __init__(self,
+                 num_decoders,
+                 num_layers,
+                 r1,
+                 r2,
+                 num_f_maps,
+                 input_dim,
+                 num_classes,
+                 channel_masking_rate,
+                 sample_rate):
+        super().__init__()
+        self.model = MyTransformer(num_decoders, num_layers, r1, r2,
+                                   num_f_maps, input_dim, num_classes,
+                                   channel_masking_rate)
+        self.num_classes = num_classes
+        self.sample_rate = sample_rate
+        self.ce = nn.CrossEntropyLoss(ignore_index=-100)
+        self.mse = nn.MSELoss(reduction='none')
+
+    def init_weights(self) -> None:
+        """Initiate the parameters from scratch."""
+        pass
+
+    def forward(self, inputs, data_samples, mode, **kwargs):
+        """The unified entry for a forward process in both training and test.
+
+        The method should accept three modes:
+
+        - ``tensor``: Forward the whole network and return tensor or tuple
+          of tensor without any post-processing, same as a common nn.Module.
+        - ``predict``: Forward and return the predictions, which are fully
+          processed to a list of :obj:`ActionDataSample`.
+        - ``loss``: Forward and return a dict of losses according to the
+          given inputs and data samples.
+
+        Note that this method doesn't handle either back propagation or
+        optimizer updating, which are done in the :meth:`train_step`.
+
+        Args:
+            inputs (Tensor): The input tensor with shape (N, C, ...) in
+                general.
+            data_samples (List[:obj:`ActionDataSample`], optional): The
+                annotation data of every sample. Defaults to None.
+            mode (str): Return what kind of value. Defaults to ``tensor``.
+
+        Returns:
+            The return type depends on ``mode``.
+
+            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
+            - If ``mode="predict"``, return a list of ``ActionDataSample``.
+            - If ``mode="loss"``, return a dict of tensor.
+        """
+        # Pad the variable-length feature sequences in the batch to the
+        # longest one, shape (N, C, L_max).
+        batch_inputs = torch.zeros(
+            len(inputs),
+            inputs[0].size(0),
+            max(tensor.size(1) for tensor in inputs),
+            dtype=torch.float,
+            device=inputs[0].device)
+        for i, feature in enumerate(inputs):
+            batch_inputs[i, :, :feature.size(1)] = feature
+
+        if mode == 'tensor':
+            return self._forward(batch_inputs, data_samples, **kwargs)
+        elif mode == 'predict':
+            return self.predict(batch_inputs, data_samples, **kwargs)
+        elif mode == 'loss':
+            return self.loss(batch_inputs, data_samples, **kwargs)
+        else:
+            raise RuntimeError(f'Invalid mode "{mode}". '
+                               'Only supports loss, predict and tensor mode')
+
+    def _prepare_targets(self, batch_inputs, batch_data_samples):
+        """Pad the frame-wise targets and build the valid-frame mask.
+
+        Returns:
+            Tuple[Tensor, Tensor]: the padded targets with shape (N, L) and
+            the mask with shape (N, num_classes, L). Padded positions use
+            the ignore index -100 and a zero mask.
+        """
+        device = batch_inputs.device
+        max_len = batch_inputs.size(2)
+        batch_target = torch.full((len(batch_data_samples), max_len),
+                                  -100,
+                                  dtype=torch.long,
+                                  device=device)
+        mask = torch.zeros(
+            len(batch_data_samples),
+            self.num_classes,
+            max_len,
+            dtype=torch.float,
+            device=device)
+        for i, data_sample in enumerate(batch_data_samples):
+            classes = torch.from_numpy(data_sample.classes).long().to(device)
+            num_frames = classes.size(0)
+            batch_target[i, :num_frames] = classes
+            mask[i, :, :num_frames] = 1.
+        return batch_target, mask
+
+    def loss(self, batch_inputs, batch_data_samples, **kwargs):
+        """Calculate losses from a batch of inputs and data samples.
+
+        The loss follows the original ASFormer training objective: a
+        frame-wise cross-entropy term plus a truncated MSE smoothing term,
+        summed over the encoder and every decoder stage.
+
+        Args:
+            batch_inputs (Tensor): Padded frame-wise features with shape
+                (N, C, L).
+            batch_data_samples (List[:obj:`ActionDataSample`]): The batch
+                data samples, carrying the frame-wise ``classes`` targets.
+
+        Returns:
+            dict: A dictionary of loss components.
+        """
+        batch_target, mask = self._prepare_targets(batch_inputs,
+                                                   batch_data_samples)
+        predictions = self.model(batch_inputs, mask)
+
+        loss = 0.
+        for p in predictions:
+            # Frame-wise classification loss for this stage.
+            loss = loss + self.ce(
+                p.transpose(2, 1).contiguous().view(-1, self.num_classes),
+                batch_target.view(-1))
+            # Truncated MSE smoothing loss between neighbouring frames.
+            loss = loss + 0.15 * torch.mean(
+                torch.clamp(
+                    self.mse(
+                        F.log_softmax(p[:, :, 1:], dim=1),
+                        F.log_softmax(p.detach()[:, :, :-1], dim=1)),
+                    min=0,
+                    max=16) * mask[:, :, 1:])
+        return dict(loss=loss)
+
+    def predict(self, batch_inputs, batch_data_samples, **kwargs):
+        """Define the computation performed at every call when testing."""
+        index2label = batch_data_samples[0].index2label
+        _, mask = self._prepare_targets(batch_inputs, batch_data_samples)
+
+        predictions = self.model(batch_inputs, mask)
+        # Only the output of the last decoder stage is used for recognition.
+        predicted = torch.argmax(F.softmax(predictions[-1], dim=1), dim=1)
+        predicted = predicted.squeeze(0)
+
+        ground = [
+            index2label[idx] for idx in batch_data_samples[0].classes
+        ]
+        recognition = [
+            index2label[predicted[i].item()] for i in range(len(predicted))
+        ]
+        return [dict(ground=ground, recognition=recognition)]
+
+    def _forward(self, batch_inputs, batch_data_samples=None, **kwargs):
+        """Define the computation performed at every call.
+
+        Args:
+            batch_inputs (Tensor): Padded frame-wise features with shape
+                (N, C, L).
+
+        Returns:
+            Tensor: The stage-wise frame predictions with shape
+            (num_decoders + 1, N, num_classes, L).
+        """
+        if batch_data_samples is None:
+            mask = batch_inputs.new_ones(
+                batch_inputs.size(0), self.num_classes, batch_inputs.size(2))
+        else:
+            _, mask = self._prepare_targets(batch_inputs, batch_data_samples)
+        return self.model(batch_inputs, mask)
+
+
+def exponential_descrease(idx_decoder, p=3):
+    """Exponentially decaying alpha, exp(-p * idx_decoder), for decoders."""
+    return math.exp(-p * idx_decoder)
+
+
+class AttentionHelper(nn.Module):
+
+    def __init__(self):
+        super(AttentionHelper, self).__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
+        """Scaled dot-product attention with a padding mask.
+
+        :param proj_query: shape of (B, C, L)
+            => (Batch_Size, Feature_Dimension, Length)
+        :param proj_key: shape of (B, C, L)
+        :param proj_val: shape of (B, C, L)
+        :param padding_mask: shape of (B, C, L)
+        :return: attention value of shape (B, C, L)
+        """
+        m, c1, l1 = proj_query.shape
+        m, c2, l2 = proj_key.shape
+
+        assert c1 == c2
+
+        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)
+        attention = energy / np.sqrt(c1)
+        attention = attention + torch.log(padding_mask + 1e-6)
+        attention = self.softmax(attention)
+        attention = attention * padding_mask
+        attention = attention.permute(0, 2, 1)
+        out = torch.bmm(proj_val, attention)
+        return out, attention
+
+
+class AttLayer(nn.Module):
+
+    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):
+        super(AttLayer, self).__init__()
+        self.query_conv = nn.Conv1d(
+            in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1)
+        self.key_conv = nn.Conv1d(
+            in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1)
+        self.value_conv = nn.Conv1d(
+            in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1)
+
+        self.conv_out = nn.Conv1d(
+            in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1)
+
+        self.bl = bl
+        self.stage = stage
+        self.att_type = att_type
+        assert self.att_type in ['normal_att', 'block_att', 'sliding_att']
+        assert self.stage in ['encoder', 'decoder']
+
+        self.att_helper = AttentionHelper()
+        self.window_mask = self.construct_window_mask()
+
+    def construct_window_mask(self):
+        """construct window mask of shape (1, l, l + l//2 + l//2), used for
+        sliding window self attention."""
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
+        for i in range(self.bl):
+            window_mask[:, :, i:i + self.bl] = 1
+        return window_mask.to(device)
+
+    def forward(self, x1, x2, mask):
+        query = self.query_conv(x1)
+        key = self.key_conv(x1)
+
+        if self.stage == 'decoder':
+            assert x2 is not None
+            value = self.value_conv(x2)
+        else:
+            value = self.value_conv(x1)
+
+        if self.att_type == 'normal_att':
+            return self._normal_self_att(query, key, value, mask)
+        elif self.att_type == 'block_att':
+            return self._block_wise_self_att(query, key, value, mask)
+        elif self.att_type == 'sliding_att':
+            return self._sliding_window_self_att(query, key, value, mask)
+
+    def _normal_self_att(self, q, k, v, mask):
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        m_batchsize, c1, L = q.size()
+        _, c2, L = k.size()
+        _, c3, L = v.size()
+        padding_mask = torch.ones(
+            (m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
+        output, attentions = self.att_helper.scalar_dot_att(
+            q, k, v, padding_mask)
+        output = self.conv_out(F.relu(output))
+        output = output[:, :, 0:L]
+        return output * mask[:, 0:1, :]
+
+    def _block_wise_self_att(self, q, k, v, mask):
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        m_batchsize, c1, L = q.size()
+        _, c2, L = k.size()
+        _, c3, L = v.size()
+
+        nb = L // self.bl
+        if L % self.bl != 0:
+            q = torch.cat([
+                q,
+                torch.zeros(
+                    (m_batchsize, c1, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            k = torch.cat([
+                k,
+                torch.zeros(
+                    (m_batchsize, c2, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            v = torch.cat([
+                v,
+                torch.zeros(
+                    (m_batchsize, c3, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            nb += 1
+
+        padding_mask = torch.cat([
+            torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
+            torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device)
+        ],
+                                 dim=-1)
+
+        q = q.reshape(m_batchsize, c1, nb,
+                      self.bl).permute(0, 2, 1,
+                                       3).reshape(m_batchsize * nb, c1,
+                                                  self.bl)
+        padding_mask = padding_mask.reshape(
+            m_batchsize, 1, nb,
+            self.bl).permute(0, 2, 1, 3).reshape(m_batchsize * nb, 1, self.bl)
+        k = k.reshape(m_batchsize, c2, nb,
+                      self.bl).permute(0, 2, 1,
+                                       3).reshape(m_batchsize * nb, c2,
+                                                  self.bl)
+        v = v.reshape(m_batchsize, c3, nb,
+                      self.bl).permute(0, 2, 1,
+                                       3).reshape(m_batchsize * nb, c3,
+                                                  self.bl)
+
+        output, attentions = self.att_helper.scalar_dot_att(
+            q, k, v, padding_mask)
+        output = self.conv_out(F.relu(output))
+
+        output = output.reshape(m_batchsize, nb, c3, self.bl).permute(
+            0, 2, 1, 3).reshape(m_batchsize, c3, nb * self.bl)
+        output = output[:, :, 0:L]
+        return output * mask[:, 0:1, :]
+
+    def _sliding_window_self_att(self, q, k, v, mask):
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        m_batchsize, c1, L = q.size()
+        _, c2, _ = k.size()
+        _, c3, _ = v.size()
+        nb = L // self.bl
+        if L % self.bl != 0:
+            q = torch.cat([
+                q,
+                torch.zeros(
+                    (m_batchsize, c1, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            k = torch.cat([
+                k,
+                torch.zeros(
+                    (m_batchsize, c2, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            v = torch.cat([
+                v,
+                torch.zeros(
+                    (m_batchsize, c3, self.bl - L % self.bl)).to(device)
+            ],
+                          dim=-1)
+            nb += 1
+        padding_mask = torch.cat([
+            torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
+            torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device)
+        ],
+                                 dim=-1)
+        q = q.reshape(m_batchsize, c1, nb,
+                      self.bl).permute(0, 2, 1,
+                                       3).reshape(m_batchsize * nb, c1,
+                                                  self.bl)
+        k = torch.cat([
+            torch.zeros(m_batchsize, c2, self.bl // 2).to(device), k,
+            torch.zeros(m_batchsize, c2, self.bl // 2).to(device)
+        ],
+                      dim=-1)
+        v = torch.cat([
+            torch.zeros(m_batchsize, c3, self.bl // 2).to(device), v,
+            torch.zeros(m_batchsize, c3, self.bl // 2).to(device)
+        ],
+                      dim=-1)
+        padding_mask = torch.cat([
+            torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
+            padding_mask,
+            torch.zeros(m_batchsize, 1, self.bl // 2).to(device)
+        ],
+                                 dim=-1)
+        k = torch.cat([
+            k[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 2]
+            for i in range(nb)
+        ],
+                      dim=0)
+        v = torch.cat([
+            v[:, :, i * self.bl:(i + 1) * self.bl + (self.bl // 2) * 2]
+            for i in range(nb)
+        ],
+                      dim=0)
+        padding_mask = torch.cat([
+            padding_mask[:, :, i * self.bl:(i + 1) * self.bl +
+                         (self.bl // 2) * 2] for i in range(nb)
+        ],
+                                 dim=0)
+        final_mask = self.window_mask.repeat(m_batchsize * nb, 1,
+                                             1) * padding_mask
+
+        output, attention = self.att_helper.scalar_dot_att(
+            q, k, v, final_mask)
+        output = self.conv_out(F.relu(output))
+
+        output = output.reshape(m_batchsize, nb, -1, self.bl).permute(
+            0, 2, 1, 3).reshape(m_batchsize, -1, nb * self.bl)
+        output = output[:, :, 0:L]
+        return output * mask[:, 0:1, :]
+
+
+class MultiHeadAttLayer(nn.Module):
+
+    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type,
+                 num_head):
+        super(MultiHeadAttLayer, self).__init__()
+        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
+        self.layers = nn.ModuleList([
+            copy.deepcopy(
+                AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage,
+                         att_type)) for i in range(num_head)
+        ])
+        self.dropout = nn.Dropout(p=0.5)
+
+    def forward(self, x1, x2, mask):
+        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
+        out = self.conv_out(self.dropout(out))
+        return out
+
+
+class ConvFeedForward(nn.Module):
+
+    def __init__(self, dilation, in_channels, out_channels):
+        super(ConvFeedForward, self).__init__()
+        self.layer = nn.Sequential(
+            nn.Conv1d(
+                in_channels,
+                out_channels,
+                3,
+                padding=dilation,
+                dilation=dilation), nn.ReLU())
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class FCFeedForward(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super(FCFeedForward, self).__init__()
+        self.layer = nn.Sequential(
+            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
+            nn.ReLU(),
+            nn.Dropout(),
+            nn.Conv1d(out_channels, out_channels, 1))
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class AttModule(nn.Module):
+
+    def __init__(self, dilation, in_channels, out_channels, r1, r2, att_type,
+                 stage, alpha):
+        super(AttModule, self).__init__()
+        self.feed_forward = ConvFeedForward(dilation, in_channels,
+                                            out_channels)
+        self.instance_norm = nn.InstanceNorm1d(
+            in_channels, track_running_stats=False)
+        self.att_layer = AttLayer(
+            in_channels,
+            in_channels,
+            out_channels,
+            r1,
+            r1,
+            r2,
+            dilation,
+            att_type=att_type,
+            stage=stage)
+        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
+        self.dropout = nn.Dropout()
+        self.alpha = alpha
+
+    def forward(self, x, f, mask):
+        out = self.feed_forward(x)
+        out = self.alpha * self.att_layer(self.instance_norm(out), f,
+                                          mask) + out
+        out = self.conv_1x1(out)
+        out = self.dropout(out)
+        return (x + out) * mask[:, 0:1, :]
+
+
+class PositionalEncoding(nn.Module):
+    """Implement the PE function."""
+
+    def __init__(self, d_model, max_len=10000):
+        super(PositionalEncoding, self).__init__()
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).permute(0, 2, 1)
+        self.pe = nn.Parameter(pe, requires_grad=True)
+
+    def forward(self, x):
+        return x + self.pe[:, :, 0:x.shape[2]]
+
+
+class Encoder(nn.Module):
+
+    def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes,
+                 channel_masking_rate, att_type, alpha):
+        super(Encoder, self).__init__()
+        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
+        self.layers = nn.ModuleList([
+            AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type,
+                      'encoder', alpha) for i in range(num_layers)
+        ])
+
+        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
+        self.dropout = nn.Dropout2d(p=channel_masking_rate)
+        self.channel_masking_rate = channel_masking_rate
+
+    def forward(self, x, mask):
+        '''
+        :param x: (N, C, L)
+        :param mask:
+        :return:
+        '''
+        if self.channel_masking_rate > 0:
+            x = x.unsqueeze(2)
+            x = self.dropout(x)
+            x = x.squeeze(2)
+
+        feature = self.conv_1x1(x)
+        for layer in self.layers:
+            feature = layer(feature, None, mask)
+
+        out = self.conv_out(feature) * mask[:, 0:1, :]
+
+        return out, feature
+
+
+class Decoder(nn.Module):
+
+    def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes,
+                 att_type, alpha):
+        super(Decoder, self).__init__()
+        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
+        self.layers = nn.ModuleList([
+            AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type,
+                      'decoder', alpha) for i in range(num_layers)
+        ])
+        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
+
+    def forward(self, x, fencoder, mask):
+        feature = self.conv_1x1(x)
+        for layer in self.layers:
+            feature = layer(feature, fencoder, mask)
+
+        out = self.conv_out(feature) * mask[:, 0:1, :]
+
+        return out, feature
+
+
+class MyTransformer(nn.Module):
+
+    def __init__(self, num_decoders, num_layers, r1, r2, num_f_maps,
+                 input_dim, num_classes, channel_masking_rate):
+        super(MyTransformer, self).__init__()
+        self.encoder = Encoder(
+            num_layers,
+            r1,
+            r2,
+            num_f_maps,
+            input_dim,
+            num_classes,
+            channel_masking_rate,
+            att_type='sliding_att',
+            alpha=1)
+        self.decoders = nn.ModuleList([
+            copy.deepcopy(
+                Decoder(
+                    num_layers,
+                    r1,
+                    r2,
+                    num_f_maps,
+                    num_classes,
+                    num_classes,
+                    att_type='sliding_att',
+                    alpha=exponential_descrease(s)))
+            for s in range(num_decoders)
+        ])
+
+    def forward(self, x, mask):
+        out, feature = self.encoder(x, mask)
+        outputs = out.unsqueeze(0)
+
+        for decoder in self.decoders:
+            out, feature = decoder(
+                F.softmax(out, dim=1) * mask[:, 0:1, :],
+                feature * mask[:, 0:1, :], mask)
+            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
+
+        return outputs
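
Note: a quick shape check for the ASFormer trunk defined above, using the hyper-parameters from configs/_base_/models/asformer.py. The snippet is illustrative only (random inputs, all-ones mask) and assumes the patch is installed so the module can be imported.

# Illustrative only: one dummy forward through MyTransformer.
import torch

from mmaction.models.action_segmentors.asformer import MyTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyTransformer(
    num_decoders=3, num_layers=10, r1=2, r2=2, num_f_maps=64,
    input_dim=2048, num_classes=11, channel_masking_rate=0.5).to(device)
x = torch.randn(1, 2048, 100, device=device)  # (N, C, L) frame-wise features
mask = torch.ones(1, 11, 100, device=device)  # all 100 frames are valid
out = model(x, mask)
print(out.shape)  # torch.Size([4, 1, 11, 100]): encoder + 3 decoder stages
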