From cc3b74b5e874255a74f0357ad0c65972f2cd7b4d Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sat, 2 Nov 2024 22:23:51 +0800 Subject: [PATCH] [Fix] Fix lint (#1598) * [Fix] Fix lint * [Fix] Fix lint * Update mmengine/dist/utils.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- .github/workflows/lint.yml | 4 +-- .github/workflows/pr_stage_test.yml | 4 +++ .pre-commit-config-zh-cn.yaml | 13 +++++--- .pre-commit-config.yaml | 17 +++++----- mmengine/_strategy/base.py | 2 +- mmengine/_strategy/colossalai.py | 2 +- mmengine/_strategy/distributed.py | 2 +- mmengine/config/config.py | 31 ++++++++++--------- mmengine/dist/utils.py | 2 +- mmengine/evaluator/metric.py | 6 ++-- mmengine/hooks/profiler_hook.py | 2 +- mmengine/logging/history_buffer.py | 2 +- .../model/base_model/data_preprocessor.py | 14 ++++----- mmengine/model/weight_init.py | 14 ++++----- .../wrappers/fully_sharded_distributed.py | 7 +++-- .../optim/optimizer/apex_optimizer_wrapper.py | 3 +- .../optim/optimizer/default_constructor.py | 4 +-- .../optim/optimizer/optimizer_wrapper_dict.py | 3 +- mmengine/optim/scheduler/lr_scheduler.py | 14 ++++----- mmengine/optim/scheduler/param_scheduler.py | 22 ++++++------- mmengine/registry/default_scope.py | 2 +- mmengine/registry/registry.py | 2 +- mmengine/registry/utils.py | 2 +- mmengine/runner/_flexible_runner.py | 6 ++-- mmengine/runner/checkpoint.py | 16 +++++----- mmengine/runner/runner.py | 6 ++-- mmengine/structures/base_data_element.py | 2 +- mmengine/structures/instance_data.py | 2 +- mmengine/utils/dl_utils/hub.py | 1 + mmengine/utils/dl_utils/time_counter.py | 4 +-- mmengine/utils/misc.py | 2 +- mmengine/utils/package_utils.py | 4 +-- mmengine/utils/progressbar.py | 8 ++--- mmengine/utils/progressbar_rich.py | 4 +-- mmengine/visualization/vis_backend.py | 12 +++---- mmengine/visualization/visualizer.py | 4 +-- tests/test_config/test_config.py | 4 +-- .../test_backends/test_petrel_backend.py | 8 ++--- tests/test_model/test_averaged_model.py | 3 +- .../test_scheduler/test_lr_scheduler.py | 2 +- .../test_scheduler/test_momentum_scheduler.py | 2 +- .../test_scheduler/test_param_scheduler.py | 2 +- tests/test_runner/test_checkpoint.py | 2 +- tests/test_runner/test_runner.py | 2 +- tests/test_structures/test_data_element.py | 8 ++--- tests/test_visualizer/test_visualizer.py | 2 +- 46 files changed, 146 insertions(+), 134 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 075baad95c..0e8d92b9d6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -11,10 +11,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.7 + - name: Set up Python 3.10.15 uses: actions/setup-python@v2 with: - python-version: 3.7 + python-version: '3.10.15' - name: Install pre-commit hook run: | pip install pre-commit diff --git a/.github/workflows/pr_stage_test.yml b/.github/workflows/pr_stage_test.yml index b2576cfd57..e072c752cd 100644 --- a/.github/workflows/pr_stage_test.yml +++ b/.github/workflows/pr_stage_test.yml @@ -1,5 +1,9 @@ name: pr_stage_test +env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true + + on: pull_request: paths-ignore: diff --git a/.pre-commit-config-zh-cn.yaml b/.pre-commit-config-zh-cn.yaml index 7395970e6a..02e009fd74 100644 --- a/.pre-commit-config-zh-cn.yaml +++ b/.pre-commit-config-zh-cn.yaml @@ -1,7 +1,11 @@ exclude: ^tests/data/ repos: - - repo: https://gitee.com/openmmlab/mirrors-flake8 - rev: 5.0.4 + - repo: https://github.com/pre-commit/pre-commit + rev: v4.0.0 + hooks: + - id: validate_manifest + - repo: https://github.com/PyCQA/flake8 + rev: 7.1.1 hooks: - id: flake8 - repo: https://gitee.com/openmmlab/mirrors-isort @@ -13,7 +17,7 @@ repos: hooks: - id: yapf - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks - rev: v4.3.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-yaml @@ -55,7 +59,7 @@ repos: args: ["mmengine", "tests"] - id: remove-improper-eol-in-cn-docs - repo: https://gitee.com/openmmlab/mirrors-mypy - rev: v0.812 + rev: v1.2.0 hooks: - id: mypy exclude: |- @@ -63,3 +67,4 @@ repos: ^examples | ^docs ) + additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1eb665c803..c8edd013c6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,11 @@ exclude: ^tests/data/ repos: + - repo: https://github.com/pre-commit/pre-commit + rev: v4.0.0 + hooks: + - id: validate_manifest - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/PyCQA/isort @@ -13,7 +17,7 @@ repos: hooks: - id: yapf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: check-yaml @@ -34,12 +38,8 @@ repos: - mdformat-openmmlab - mdformat_frontmatter - linkify-it-py - - repo: https://github.com/codespell-project/codespell - rev: v2.2.1 - hooks: - - id: codespell - repo: https://github.com/myint/docformatter - rev: v1.3.1 + rev: 06907d0 hooks: - id: docformatter args: ["--in-place", "--wrap-descriptions", "79"] @@ -55,7 +55,7 @@ repos: args: ["mmengine", "tests"] - id: remove-improper-eol-in-cn-docs - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.812 + rev: v1.2.0 hooks: - id: mypy exclude: |- @@ -63,3 +63,4 @@ repos: ^examples | ^docs ) + additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"] diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py index 5df3a79c92..a713da9a70 100644 --- a/mmengine/_strategy/base.py +++ b/mmengine/_strategy/base.py @@ -499,7 +499,7 @@ def build_optim_wrapper( '"type" and "constructor" are not in ' f'optimizer, but got {name}={optim}') optim_wrappers[name] = optim - return OptimWrapperDict(**optim_wrappers) + return OptimWrapperDict(**optim_wrappers) # type: ignore else: raise TypeError('optimizer wrapper should be an OptimWrapper ' f'object or dict, but got {optim_wrapper}') diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py index cfbb925c67..13d9f38fc3 100644 --- a/mmengine/_strategy/colossalai.py +++ b/mmengine/_strategy/colossalai.py @@ -361,7 +361,7 @@ def resume( map_location: Union[str, Callable] = 'default', callback: Optional[Callable] = None, ) -> dict: - """override this method since colossalai resume optimizer from filename + """Override this method since colossalai resume optimizer from filename directly.""" self.logger.info(f'Resume checkpoint from {filename}') diff --git a/mmengine/_strategy/distributed.py b/mmengine/_strategy/distributed.py index 6c969b85b1..dbe17d5aeb 100644 --- a/mmengine/_strategy/distributed.py +++ b/mmengine/_strategy/distributed.py @@ -53,7 +53,7 @@ def _setup_distributed( # type: ignore init_dist(launcher, backend, **kwargs) def convert_model(self, model: nn.Module) -> nn.Module: - """convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` + """Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm`` (SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers. Args: diff --git a/mmengine/config/config.py b/mmengine/config/config.py index f85795066a..36f92f0b3a 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -393,7 +393,7 @@ class Config: def __init__( self, - cfg_dict: dict = None, + cfg_dict: Optional[dict] = None, cfg_text: Optional[str] = None, filename: Optional[Union[str, Path]] = None, env_variables: Optional[dict] = None, @@ -1227,7 +1227,8 @@ def is_base_line(c): if base_code is not None: base_code = ast.Expression( # type: ignore body=base_code.value) # type: ignore - base_files = eval(compile(base_code, '', mode='eval')) + base_files = eval(compile(base_code, '', + mode='eval')) # type: ignore else: base_files = [] elif file_format in ('.yml', '.yaml', '.json'): @@ -1288,7 +1289,7 @@ def _get_cfg_path(cfg_path: str, def _merge_a_into_b(a: dict, b: dict, allow_list_keys: bool = False) -> dict: - """merge dict ``a`` into dict ``b`` (non-inplace). + """Merge dict ``a`` into dict ``b`` (non-inplace). Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid in-place modifications. @@ -1358,22 +1359,22 @@ def auto_argparser(description=None): @property def filename(self) -> str: - """get file name of config.""" + """Get file name of config.""" return self._filename @property def text(self) -> str: - """get config text.""" + """Get config text.""" return self._text @property def env_variables(self) -> dict: - """get used environment variables.""" + """Get used environment variables.""" return self._env_variables @property def pretty_text(self) -> str: - """get formatted python config text.""" + """Get formatted python config text.""" indent = 4 @@ -1727,17 +1728,17 @@ def to_dict(self, keep_imported: bool = False): class DictAction(Action): - """ - argparse action to split an argument into KEY=VALUE form - on the first = and append to a dictionary. List options can - be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit - brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build - list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' + """Argparse action to split an argument into KEY=VALUE form on the first = + and append to a dictionary. + + List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3', + or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested + brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' """ @staticmethod def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: - """parse int/float/bool value in the string.""" + """Parse int/float/bool value in the string.""" try: return int(val) except ValueError: @@ -1822,7 +1823,7 @@ def __call__(self, parser: ArgumentParser, namespace: Namespace, values: Union[str, Sequence[Any], None], - option_string: str = None): + option_string: str = None): # type: ignore """Parse Variables in string and add them into argparser. Args: diff --git a/mmengine/dist/utils.py b/mmengine/dist/utils.py index 3c136973bb..5d32cec36b 100644 --- a/mmengine/dist/utils.py +++ b/mmengine/dist/utils.py @@ -563,7 +563,7 @@ def cast_data_device( Tensor or list or dict: ``data`` was casted to ``device``. """ if out is not None: - if type(data) != type(out): + if type(data) is not type(out): raise TypeError( 'out should be the same type with data, but got data is ' f'{type(data)} and out is {type(data)}') diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index 6e6d40bee3..1292ce61ec 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -175,11 +175,11 @@ def __init__(self, self.out_file_path = out_file_path def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: - """transfer tensors in predictions to CPU.""" + """Transfer tensors in predictions to CPU.""" self.results.extend(_to_cpu(predictions)) def compute_metrics(self, results: list) -> dict: - """dump the prediction results to a pickle file.""" + """Dump the prediction results to a pickle file.""" dump(results, self.out_file_path) print_log( f'Results has been saved to {self.out_file_path}.', @@ -188,7 +188,7 @@ def compute_metrics(self, results: list) -> dict: def _to_cpu(data: Any) -> Any: - """transfer all tensors and BaseDataElement to cpu.""" + """Transfer all tensors and BaseDataElement to cpu.""" if isinstance(data, (Tensor, BaseDataElement)): return data.to('cpu') elif isinstance(data, list): diff --git a/mmengine/hooks/profiler_hook.py b/mmengine/hooks/profiler_hook.py index 6339a5da92..dae84b85f5 100644 --- a/mmengine/hooks/profiler_hook.py +++ b/mmengine/hooks/profiler_hook.py @@ -233,7 +233,7 @@ def after_train_epoch(self, runner): self._export_chrome_trace(runner) def after_train_iter(self, runner, batch_idx, data_batch, outputs): - """profiler will call `step` method if it is not closed.""" + """Profiler will call `step` method if it is not closed.""" if not self._closed: self.profiler.step() if runner.iter == self.profile_times - 1 and not self.by_epoch: diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py index 58effa8152..a50de22c65 100644 --- a/mmengine/logging/history_buffer.py +++ b/mmengine/logging/history_buffer.py @@ -58,7 +58,7 @@ def _set_default_statistics(self) -> None: self._statistics_methods.setdefault('mean', HistoryBuffer.mean) def update(self, log_val: Union[int, float], count: int = 1) -> None: - """update the log history. + """Update the log history. If the length of the buffer exceeds ``self._max_length``, the oldest element will be removed from the buffer. diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index a101855203..4d621851b0 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -253,17 +253,17 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]: dict or list: Data in the same format as the model input. """ data = self.cast_data(data) # type: ignore - _batch_inputs = data['inputs'] + _batch_inputs = data['inputs'] # type: ignore # Process data with `pseudo_collate`. if is_seq_of(_batch_inputs, torch.Tensor): batch_inputs = [] for _batch_input in _batch_inputs: # channel transform if self._channel_conversion: - _batch_input = _batch_input[[2, 1, 0], ...] + _batch_input = _batch_input[[2, 1, 0], ...] # type: ignore # Convert to float after channel conversion to ensure # efficiency - _batch_input = _batch_input.float() + _batch_input = _batch_input.float() # type: ignore # Normalization. if self._enable_normalize: if self.mean.shape[0] == 3: @@ -302,7 +302,7 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]: else: raise TypeError('Output of `cast_data` should be a dict of ' 'list/tuple with inputs and data_samples, ' - f'but got {type(data)}: {data}') - data['inputs'] = batch_inputs - data.setdefault('data_samples', None) - return data + f'but got {type(data)}: {data}') # type: ignore + data['inputs'] = batch_inputs # type: ignore + data.setdefault('data_samples', None) # type: ignore + return data # type: ignore diff --git a/mmengine/model/weight_init.py b/mmengine/model/weight_init.py index a2e0b9a7a5..b6d0186ed7 100644 --- a/mmengine/model/weight_init.py +++ b/mmengine/model/weight_init.py @@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0): def bias_init_with_prob(prior_prob): - """initialize conv/fc bias value according to a given probability value.""" + """Initialize conv/fc bias value according to a given probability value.""" bias_init = float(-np.log((1 - prior_prob) / prior_prob)) return bias_init @@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor, std: float = 1., a: float = -2., b: float = 2.) -> Tensor: - r"""Fills the input Tensor with values drawn from a truncated - normal distribution. The values are effectively drawn from the - normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` - with values outside :math:`[a, b]` redrawn until they are within - the bounds. The method used for generating the random values works - best when :math:`a \leq \text{mean} \leq b`. + r"""Fills the input Tensor with values drawn from a truncated normal + distribution. The values are effectively drawn from the normal distribution + :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside + :math:`[a, b]` redrawn until they are within the bounds. The method used + for generating the random values works best when :math:`a \leq \text{mean} + \leq b`. Modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index d6b145ecf9..df128597b1 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -127,7 +127,8 @@ def __init__( auto_wrap_policy: Union[str, Callable, None] = None, backward_prefetch: Union[str, BackwardPrefetch, None] = None, mixed_precision: Union[dict, MixedPrecision, None] = None, - param_init_fn: Union[str, Callable[[nn.Module], None]] = None, + param_init_fn: Union[str, Callable[ + [nn.Module], None]] = None, # type: ignore # noqa: E501 use_orig_params: bool = True, **kwargs, ): @@ -362,7 +363,7 @@ def optim_state_dict( optim: torch.optim.Optimizer, group: Optional[dist.ProcessGroup] = None, ) -> Dict[str, Any]: - """copied from pytorch 2.0.1 which has fixed some bugs.""" + """Copied from pytorch 2.0.1 which has fixed some bugs.""" state_dict_settings = FullyShardedDataParallel.get_state_dict_type( model) return FullyShardedDataParallel._optim_state_dict_impl( @@ -384,7 +385,7 @@ def set_state_dict_type( state_dict_config: Optional[StateDictConfig] = None, optim_state_dict_config: Optional[OptimStateDictConfig] = None, ) -> StateDictSettings: - """copied from pytorch 2.0.1 which has fixed some bugs.""" + """Copied from pytorch 2.0.1 which has fixed some bugs.""" import torch.distributed.fsdp._traversal_utils as traversal_utils _state_dict_type_to_config = { StateDictType.FULL_STATE_DICT: FullStateDictConfig, diff --git a/mmengine/optim/optimizer/apex_optimizer_wrapper.py b/mmengine/optim/optimizer/apex_optimizer_wrapper.py index 5f2f6f4a1b..a2e6190460 100644 --- a/mmengine/optim/optimizer/apex_optimizer_wrapper.py +++ b/mmengine/optim/optimizer/apex_optimizer_wrapper.py @@ -123,8 +123,7 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None: self._inner_count += 1 def state_dict(self) -> dict: - """Get the state dictionary of :attr:`optimizer` and - :attr:`apex_amp`. + """Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`. Based on the state dictionary of the optimizer, the returned state dictionary will add a key named "apex_amp". diff --git a/mmengine/optim/optimizer/default_constructor.py b/mmengine/optim/optimizer/default_constructor.py index ec223a7967..b623a3e70e 100644 --- a/mmengine/optim/optimizer/default_constructor.py +++ b/mmengine/optim/optimizer/default_constructor.py @@ -131,7 +131,7 @@ def __init__(self, self._validate_cfg() def _validate_cfg(self) -> None: - """verify the correctness of the config.""" + """Verify the correctness of the config.""" if not isinstance(self.paramwise_cfg, dict): raise TypeError('paramwise_cfg should be None or a dict, ' f'but got {type(self.paramwise_cfg)}') @@ -155,7 +155,7 @@ def _validate_cfg(self) -> None: raise ValueError('base_wd should not be None') def _is_in(self, param_group: dict, param_group_list: list) -> bool: - """check whether the `param_group` is in the`param_group_list`""" + """Check whether the `param_group` is in the`param_group_list`""" assert is_list_of(param_group_list, dict) param = set(param_group['params']) param_set = set() diff --git a/mmengine/optim/optimizer/optimizer_wrapper_dict.py b/mmengine/optim/optimizer/optimizer_wrapper_dict.py index a18fd99cae..efa7705c9e 100644 --- a/mmengine/optim/optimizer/optimizer_wrapper_dict.py +++ b/mmengine/optim/optimizer/optimizer_wrapper_dict.py @@ -161,8 +161,7 @@ def load_state_dict(self, state_dict: dict) -> None: self.optim_wrappers[name].load_state_dict(_state_dict) def items(self) -> Iterator[Tuple[str, OptimWrapper]]: - """A generator to get the name and corresponding - :obj:`OptimWrapper`""" + """A generator to get the name and corresponding :obj:`OptimWrapper`""" yield from self.optim_wrappers.items() def values(self) -> Iterator[OptimWrapper]: diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index b12c60d0cf..13bc61d542 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -223,13 +223,13 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler): @PARAM_SCHEDULERS.register_module() class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler): - r"""Sets the learning rate of each parameter group according to the - 1cycle learning rate policy. The 1cycle policy anneals the learning - rate from an initial learning rate to some maximum learning rate and then - from that maximum learning rate to some minimum learning rate much lower - than the initial learning rate. - This policy was initially described in the paper `Super-Convergence: - Very Fast Training of Neural Networks Using Large Learning Rates`_. + r"""Sets the learning rate of each parameter group according to the 1cycle + learning rate policy. The 1cycle policy anneals the learning rate from an + initial learning rate to some maximum learning rate and then from that + maximum learning rate to some minimum learning rate much lower than the + initial learning rate. This policy was initially described in the paper + `Super-Convergence: Very Fast Training of Neural Networks Using Large + Learning Rates`_. The 1cycle learning rate policy changes the learning rate after every batch. `step` should be called after a batch has been used for training. diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index af89ccaea7..2dcb1af072 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -565,9 +565,9 @@ def _get_value(self): @PARAM_SCHEDULERS.register_module() class CosineAnnealingParamScheduler(_ParamScheduler): - r"""Set the parameter value of each parameter group using a cosine - annealing schedule, where :math:`\eta_{max}` is set to the initial value - and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + r"""Set the parameter value of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial value and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: .. math:: \begin{aligned} @@ -617,7 +617,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler): .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: https://arxiv.org/abs/1608.03983 - """ + """ # noqa: E501 def __init__(self, optimizer: Union[Optimizer, BaseOptimWrapper], @@ -890,13 +890,13 @@ def _get_value(self): @PARAM_SCHEDULERS.register_module() class OneCycleParamScheduler(_ParamScheduler): - r"""Sets the parameters of each parameter group according to the - 1cycle learning rate policy. The 1cycle policy anneals the learning - rate from an initial learning rate to some maximum learning rate and then - from that maximum learning rate to some minimum learning rate much lower - than the initial learning rate. - This policy was initially described in the paper `Super-Convergence: - Very Fast Training of Neural Networks Using Large Learning Rates`_. + r"""Sets the parameters of each parameter group according to the 1cycle + learning rate policy. The 1cycle policy anneals the learning rate from an + initial learning rate to some maximum learning rate and then from that + maximum learning rate to some minimum learning rate much lower than the + initial learning rate. This policy was initially described in the paper + `Super-Convergence: Very Fast Training of Neural Networks Using Large + Learning Rates`_. The 1cycle learning rate policy changes the learning rate after every batch. `step` should be called after a batch has been used for training. diff --git a/mmengine/registry/default_scope.py b/mmengine/registry/default_scope.py index c9f1afcaba..f1347689e0 100644 --- a/mmengine/registry/default_scope.py +++ b/mmengine/registry/default_scope.py @@ -81,7 +81,7 @@ def get_current_instance(cls) -> Optional['DefaultScope']: @classmethod @contextmanager def overwrite_default_scope(cls, scope_name: Optional[str]) -> Generator: - """overwrite the current default scope with `scope_name`""" + """Overwrite the current default scope with `scope_name`""" if scope_name is None: yield else: diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 31fd44d827..e7d8962be4 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -332,7 +332,7 @@ def _get_root_registry(self) -> 'Registry': return root def import_from_location(self) -> None: - """import modules from the pre-defined locations in self._location.""" + """Import modules from the pre-defined locations in self._location.""" if not self._imported: # Avoid circular import from ..logging import print_log diff --git a/mmengine/registry/utils.py b/mmengine/registry/utils.py index 568970bbf1..2737e879a7 100644 --- a/mmengine/registry/utils.py +++ b/mmengine/registry/utils.py @@ -109,7 +109,7 @@ def init_default_scope(scope: str) -> None: if current_scope.scope_name != scope: # type: ignore print_log( 'The current default scope ' # type: ignore - f'"{current_scope.scope_name}" is not "{scope}", ' + f'"{current_scope.scope_name}" is not "{scope}", ' # type: ignore '`init_default_scope` will force set the current' f'default scope to "{scope}".', logger='current', diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py index 6d727fb4d5..5160a5cfb0 100644 --- a/mmengine/runner/_flexible_runner.py +++ b/mmengine/runner/_flexible_runner.py @@ -540,7 +540,7 @@ def timestamp(self): @property def hooks(self): - """list[:obj:`Hook`]: A list of registered hooks.""" + """List[:obj:`Hook`]: A list of registered hooks.""" return self._hooks @property @@ -1117,7 +1117,7 @@ def get_hooks_info(self) -> str: return '\n'.join(stage_hook_infos) def load_or_resume(self): - """load or resume checkpoint.""" + """Load or resume checkpoint.""" if self._has_loaded: return None @@ -1539,7 +1539,7 @@ def save_checkpoint( file_client_args: Optional[dict] = None, save_optimizer: bool = True, save_param_scheduler: bool = True, - meta: dict = None, + meta: Optional[dict] = None, by_epoch: bool = True, backend_args: Optional[dict] = None, ): diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 60d71a735b..2bf5f50f7c 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -309,7 +309,7 @@ def _get_checkpoint_loader(cls, path): @classmethod def load_checkpoint(cls, filename, map_location=None, logger='current'): - """load checkpoint through URL scheme path. + """Load checkpoint through URL scheme path. Args: filename (str): checkpoint file name with given prefix @@ -332,7 +332,7 @@ def load_checkpoint(cls, filename, map_location=None, logger='current'): @CheckpointLoader.register_scheme(prefixes='') def load_from_local(filename, map_location): - """load checkpoint by local file path. + """Load checkpoint by local file path. Args: filename (str): local checkpoint file path @@ -353,7 +353,7 @@ def load_from_http(filename, map_location=None, model_dir=None, progress=os.isatty(0)): - """load checkpoint through HTTP or HTTPS scheme path. In distributed + """Load checkpoint through HTTP or HTTPS scheme path. In distributed setting, this function only download checkpoint at local rank 0. Args: @@ -386,7 +386,7 @@ def load_from_http(filename, @CheckpointLoader.register_scheme(prefixes='pavi://') def load_from_pavi(filename, map_location=None): - """load checkpoint through the file path prefixed with pavi. In distributed + """Load checkpoint through the file path prefixed with pavi. In distributed setting, this function download ckpt at all ranks to different temporary directories. @@ -419,7 +419,7 @@ def load_from_pavi(filename, map_location=None): @CheckpointLoader.register_scheme( prefixes=[r'(\S+\:)?s3://', r'(\S+\:)?petrel://']) def load_from_ceph(filename, map_location=None, backend='petrel'): - """load checkpoint through the file path prefixed with s3. In distributed + """Load checkpoint through the file path prefixed with s3. In distributed setting, this function download ckpt at all ranks to different temporary directories. @@ -441,7 +441,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'): @CheckpointLoader.register_scheme(prefixes=('modelzoo://', 'torchvision://')) def load_from_torchvision(filename, map_location=None): - """load checkpoint through the file path prefixed with modelzoo or + """Load checkpoint through the file path prefixed with modelzoo or torchvision. Args: @@ -467,7 +467,7 @@ def load_from_torchvision(filename, map_location=None): @CheckpointLoader.register_scheme(prefixes=('open-mmlab://', 'openmmlab://')) def load_from_openmmlab(filename, map_location=None): - """load checkpoint through the file path prefixed with open-mmlab or + """Load checkpoint through the file path prefixed with open-mmlab or openmmlab. Args: @@ -510,7 +510,7 @@ def load_from_openmmlab(filename, map_location=None): @CheckpointLoader.register_scheme(prefixes='mmcls://') def load_from_mmcls(filename, map_location=None): - """load checkpoint through the file path prefixed with mmcls. + """Load checkpoint through the file path prefixed with mmcls. Args: filename (str): checkpoint file path with mmcls prefix diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 68716ab253..7d1f655aad 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -579,7 +579,7 @@ def timestamp(self): @property def hooks(self): - """list[:obj:`Hook`]: A list of registered hooks.""" + """List[:obj:`Hook`]: A list of registered hooks.""" return self._hooks @property @@ -720,7 +720,7 @@ def set_randomness(self, def build_logger(self, log_level: Union[int, str] = 'INFO', - log_file: str = None, + log_file: Optional[str] = None, **kwargs) -> MMLogger: """Build a global asscessable MMLogger. @@ -1677,7 +1677,7 @@ def get_hooks_info(self) -> str: return '\n'.join(stage_hook_infos) def load_or_resume(self) -> None: - """load or resume checkpoint.""" + """Load or resume checkpoint.""" if self._has_loaded: return None diff --git a/mmengine/structures/base_data_element.py b/mmengine/structures/base_data_element.py index 53bcd5babf..8ac5a3d27d 100644 --- a/mmengine/structures/base_data_element.py +++ b/mmengine/structures/base_data_element.py @@ -387,7 +387,7 @@ def metainfo(self) -> dict: return dict(self.metainfo_items()) def __setattr__(self, name: str, value: Any): - """setattr is only used to set data.""" + """Setattr is only used to set data.""" if name in ('_metainfo_fields', '_data_fields'): if not hasattr(self, name): super().__setattr__(name, value) diff --git a/mmengine/structures/instance_data.py b/mmengine/structures/instance_data.py index 369d445f28..8633b86037 100644 --- a/mmengine/structures/instance_data.py +++ b/mmengine/structures/instance_data.py @@ -135,7 +135,7 @@ class InstanceData(BaseDataElement): """ def __setattr__(self, name: str, value: Sized): - """setattr is only used to set data. + """Setattr is only used to set data. The value must have the attribute of `__len__` and have the same length of `InstanceData`. diff --git a/mmengine/utils/dl_utils/hub.py b/mmengine/utils/dl_utils/hub.py index cf555ac766..7f7f1a087d 100644 --- a/mmengine/utils/dl_utils/hub.py +++ b/mmengine/utils/dl_utils/hub.py @@ -57,6 +57,7 @@ def load_url(url, check_hash=False, file_name=None): r"""Loads the Torch serialized object at the given URL. + If downloaded file is a zip file, it will be automatically decompressed If the object is already present in `model_dir`, it's deserialized and returned. diff --git a/mmengine/utils/dl_utils/time_counter.py b/mmengine/utils/dl_utils/time_counter.py index e4a155dd72..05c008da45 100644 --- a/mmengine/utils/dl_utils/time_counter.py +++ b/mmengine/utils/dl_utils/time_counter.py @@ -67,7 +67,7 @@ def __new__(cls, instance.log_interval = log_interval instance.warmup_interval = warmup_interval - instance.with_sync = with_sync + instance.with_sync = with_sync # type: ignore instance.tag = tag instance.logger = logger @@ -127,7 +127,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.print_time(elapsed) def print_time(self, elapsed: Union[int, float]) -> None: - """print times per count.""" + """Print times per count.""" if self.__count >= self.warmup_interval: self.__pure_inf_time += elapsed diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 15c1f89fae..23ae707a56 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -131,7 +131,7 @@ def tuple_cast(inputs, dst_type): def is_seq_of(seq: Any, expected_type: Union[Type, tuple], - seq_type: Type = None) -> bool: + seq_type: Optional[Type] = None) -> bool: """Check whether it is a sequence of some type. Args: diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index b224625f13..1816f47f07 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -69,11 +69,11 @@ def get_installed_path(package: str) -> str: else: raise e - possible_path = osp.join(pkg.location, package) + possible_path = osp.join(pkg.location, package) # type: ignore if osp.exists(possible_path): return possible_path else: - return osp.join(pkg.location, package2module(package)) + return osp.join(pkg.location, package2module(package)) # type: ignore def package2module(package: str): diff --git a/mmengine/utils/progressbar.py b/mmengine/utils/progressbar.py index 36172f04dd..47e710603b 100644 --- a/mmengine/utils/progressbar.py +++ b/mmengine/utils/progressbar.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from multiprocessing import Pool from shutil import get_terminal_size -from typing import Callable, Sequence +from typing import Callable, Optional, Sequence from .timer import Timer @@ -54,7 +54,7 @@ def start(self): self.timer = Timer() def update(self, num_tasks: int = 1): - """update progressbar. + """Update progressbar. Args: num_tasks (int): Update step size. @@ -142,8 +142,8 @@ def init_pool(process_num, initializer=None, initargs=None): def track_parallel_progress(func: Callable, tasks: Sequence, nproc: int, - initializer: Callable = None, - initargs: tuple = None, + initializer: Optional[Callable] = None, + initargs: Optional[tuple] = None, bar_width: int = 50, chunksize: int = 1, skip_first: bool = False, diff --git a/mmengine/utils/progressbar_rich.py b/mmengine/utils/progressbar_rich.py index c126866ba9..f8e04d8041 100644 --- a/mmengine/utils/progressbar_rich.py +++ b/mmengine/utils/progressbar_rich.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from multiprocessing import Pool -from typing import Callable, Iterable, Sized +from typing import Callable, Iterable, Optional, Sized from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, TaskProgressColumn, TextColumn, TimeRemainingColumn) @@ -47,7 +47,7 @@ def _tasks_with_index(tasks): def track_progress_rich(func: Callable, tasks: Iterable = tuple(), - task_num: int = None, + task_num: Optional[int] = None, nproc: int = 1, chunksize: int = 1, description: str = 'Processing', diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index f74eab1fcd..b752ec85a7 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -161,7 +161,7 @@ def add_scalars(self, pass def close(self) -> None: - """close an opened object.""" + """Close an opened object.""" pass @@ -314,7 +314,7 @@ def add_scalars(self, def _dump(self, value_dict: dict, file_path: str, file_format: str) -> None: - """dump dict to file. + """Dump dict to file. Args: value_dict (dict) : The dict data to saved. @@ -505,7 +505,7 @@ def add_scalars(self, self._wandb.log(scalar_dict, commit=self._commit) def close(self) -> None: - """close an opened wandb object.""" + """Close an opened wandb object.""" if hasattr(self, '_wandb'): self._wandb.join() @@ -629,7 +629,7 @@ def add_scalars(self, self.add_scalar(key, value, step) def close(self): - """close an opened tensorboard object.""" + """Close an opened tensorboard object.""" if hasattr(self, '_tensorboard'): self._tensorboard.close() @@ -1135,7 +1135,7 @@ def add_scalars(self, self._neptune[k].append(v, step=step) def close(self) -> None: - """close an opened object.""" + """Close an opened object.""" if hasattr(self, '_neptune'): self._neptune.stop() @@ -1282,7 +1282,7 @@ def add_scalars(self, self.add_scalar(key, value, step, **kwargs) def close(self) -> None: - """close an opened dvclive object.""" + """Close an opened dvclive object.""" if not hasattr(self, '_dvclive'): return diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index 0e90c184a6..6979395aca 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -356,7 +356,7 @@ def _init_manager(self, win_name: str) -> None: @master_only def get_backend(self, name) -> 'BaseVisBackend': - """get vis backend by name. + """Get vis backend by name. Args: name (str): The name of vis backend @@ -1145,7 +1145,7 @@ def add_datasample(self, pass def close(self) -> None: - """close an opened object.""" + """Close an opened object.""" for vis_backend in self._vis_backends.values(): vis_backend.close() diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 905485c16a..e783431441 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -843,8 +843,8 @@ def _merge_delete(self): assert cfg_dict['item4'] == 'test' assert '_delete_' not in cfg_dict['item1'] - assert type(cfg_dict['item1']) == ConfigDict - assert type(cfg_dict['item2']) == ConfigDict + assert type(cfg_dict['item1']) is ConfigDict + assert type(cfg_dict['item2']) is ConfigDict def _merge_intermediate_variable(self): diff --git a/tests/test_fileio/test_backends/test_petrel_backend.py b/tests/test_fileio/test_backends/test_petrel_backend.py index ef2f85383c..6f379c3f23 100644 --- a/tests/test_fileio/test_backends/test_petrel_backend.py +++ b/tests/test_fileio/test_backends/test_petrel_backend.py @@ -300,8 +300,8 @@ def get(filepath): get_inputs.append(filepath) with build_temporary_directory() as tmp_dir, \ - patch.object(backend, 'put', side_effect=put),\ - patch.object(backend, 'get', side_effect=get),\ + patch.object(backend, 'put', side_effect=put), \ + patch.object(backend, 'get', side_effect=get), \ patch.object(backend, 'exists', return_value=False): tmp_dir = tmp_dir.replace('\\', '/') dst = f'{tmp_dir}/dir' @@ -351,7 +351,7 @@ def copyfile_from_local(src, dst): with build_temporary_directory() as tmp_dir, \ patch.object(backend, 'copyfile_from_local', - side_effect=copyfile_from_local),\ + side_effect=copyfile_from_local), \ patch.object(backend, 'exists', return_value=False): backend.copytree_from_local(tmp_dir, self.petrel_dir) @@ -427,7 +427,7 @@ def test_rmtree(self): def remove(filepath): inputs.append(filepath) - with build_temporary_directory() as tmp_dir,\ + with build_temporary_directory() as tmp_dir, \ patch.object(backend, 'remove', side_effect=remove): backend.rmtree(tmp_dir) diff --git a/tests/test_model/test_averaged_model.py b/tests/test_model/test_averaged_model.py index e3ef1c292d..6438b8bde5 100644 --- a/tests/test_model/test_averaged_model.py +++ b/tests/test_model/test_averaged_model.py @@ -13,7 +13,8 @@ class TestAveragedModel(TestCase): """Test the AveragedModel class. - Some test cases are referenced from https://github.com/pytorch/pytorch/blob/master/test/test_optim.py + Some test cases are referenced from + https://github.com/pytorch/pytorch/blob/master/test/test_optim.py """ # noqa: E501 def _test_swa_model(self, net_device, avg_device): diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index c9cd6e1fe6..22787e4709 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -102,7 +102,7 @@ def test_resume(self): rtol=0) def test_scheduler_before_optim_warning(self): - """warns if scheduler is used before optimizer.""" + """Warns if scheduler is used before optimizer.""" def call_sch_before_optim(): scheduler = StepLR(self.optimizer, gamma=0.1, step_size=3) diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py index 942259d7da..60a9713ee2 100644 --- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py +++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py @@ -120,7 +120,7 @@ def test_resume(self): rtol=0) def test_scheduler_before_optim_warning(self): - """warns if scheduler is used before optimizer.""" + """Warns if scheduler is used before optimizer.""" def call_sch_before_optim(): scheduler = StepMomentum(self.optimizer, gamma=0.1, step_size=3) diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index 557e04d9cd..a13072dc6e 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -127,7 +127,7 @@ def test_resume(self): rtol=0) def test_scheduler_before_optim_warning(self): - """warns if scheduler is used before optimizer.""" + """Warns if scheduler is used before optimizer.""" def call_sch_before_optim(): scheduler = StepParamScheduler( diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py index b846616428..4655a4c5da 100644 --- a/tests/test_runner/test_checkpoint.py +++ b/tests/test_runner/test_checkpoint.py @@ -251,7 +251,7 @@ def __init__(self): def _load_from_state_dict(self, state_dict, prefix, local_metadata, *args, **kwargs): - """load checkpoints.""" + """Load checkpoints.""" # Names of some parameters in has been changed. version = local_metadata.get('version', None) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index c8a58e9c8a..e7668054bb 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -2226,7 +2226,7 @@ def warmup_iter(self, data_batch): @HOOKS.register_module(force=True) class TestWarmupHook(Hook): - """test custom train loop.""" + """Test custom train loop.""" def before_warmup_iter(self, runner, data_batch=None): before_warmup_iter_results.append('before') diff --git a/tests/test_structures/test_data_element.py b/tests/test_structures/test_data_element.py index 883ae401d4..1cb7cd1745 100644 --- a/tests/test_structures/test_data_element.py +++ b/tests/test_structures/test_data_element.py @@ -64,7 +64,7 @@ def setup_data(self): return metainfo, data def is_equal(self, x, y): - assert type(x) == type(y) + assert type(x) is type(y) if isinstance( x, (int, float, str, list, tuple, dict, set, BaseDataElement)): return x == y @@ -141,7 +141,7 @@ def test_new(self): # test new() with no arguments new_instances = instances.new() - assert type(new_instances) == type(instances) + assert type(new_instances) is type(instances) # After deepcopy, the address of new data'element will be same as # origin, but when change new data' element will not effect the origin # element and will have new address @@ -154,7 +154,7 @@ def test_new(self): # test new() with arguments metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo, **data) - assert type(new_instances) == type(instances) + assert type(new_instances) is type(instances) assert id(new_instances.gt_instances) != id(instances.gt_instances) _, new_data = self.setup_data() new_instances.set_data(new_data) @@ -168,7 +168,7 @@ def test_clone(self): metainfo, data = self.setup_data() instances = BaseDataElement(metainfo=metainfo, **data) new_instances = instances.clone() - assert type(new_instances) == type(instances) + assert type(new_instances) is type(instances) def test_set_metainfo(self): metainfo, _ = self.setup_data() diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index f247da5051..e4ababc637 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -45,7 +45,7 @@ def add_scalars(self, self._add_scalars = True def close(self) -> None: - """close an opened object.""" + """Close an opened object.""" self._close = True