Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix lint #1598

Merged
merged 3 commits into from
Nov 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.10.15
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: '3.10.15'
- name: Install pre-commit hook
run: |
pip install pre-commit
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/pr_stage_test.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
name: pr_stage_test

env:
ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true


on:
pull_request:
paths-ignore:
Expand Down
13 changes: 9 additions & 4 deletions .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 7.1.1
hooks:
- id: flake8
- repo: https://gitee.com/openmmlab/mirrors-isort
Expand All @@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 repo 下是不是没有这个分支

hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down Expand Up @@ -55,11 +59,12 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://gitee.com/openmmlab/mirrors-mypy
rev: v0.812
rev: v1.2.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 repo 下是不是没有这个分支

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没有,要是有人用的话再去同步一下好了。我貌似没有权限

hooks:
- id: mypy
exclude: |-
(?x)(
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
17 changes: 9 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
exclude: ^tests/data/
repos:
- repo: https://github.com/pre-commit/pre-commit
rev: v4.0.0
hooks:
- id: validate_manifest
- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.1.1
hooks:
- id: flake8
- repo: https://github.com/PyCQA/isort
Expand All @@ -13,7 +17,7 @@ repos:
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v5.0.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand All @@ -34,12 +38,8 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://github.com/codespell-project/codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://github.com/myint/docformatter
rev: v1.3.1
rev: 06907d0
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
Expand All @@ -55,11 +55,12 @@ repos:
args: ["mmengine", "tests"]
- id: remove-improper-eol-in-cn-docs
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v1.2.0
hooks:
- id: mypy
exclude: |-
(?x)(
^examples
| ^docs
)
additional_dependencies: ["types-setuptools", "types-requests", "types-PyYAML"]
2 changes: 1 addition & 1 deletion mmengine/_strategy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def build_optim_wrapper(
'"type" and "constructor" are not in '
f'optimizer, but got {name}={optim}')
optim_wrappers[name] = optim
return OptimWrapperDict(**optim_wrappers)
return OptimWrapperDict(**optim_wrappers) # type: ignore
else:
raise TypeError('optimizer wrapper should be an OptimWrapper '
f'object or dict, but got {optim_wrapper}')
Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/colossalai.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def resume(
map_location: Union[str, Callable] = 'default',
callback: Optional[Callable] = None,
) -> dict:
"""override this method since colossalai resume optimizer from filename
"""Override this method since colossalai resume optimizer from filename
directly."""
self.logger.info(f'Resume checkpoint from {filename}')

Expand Down
2 changes: 1 addition & 1 deletion mmengine/_strategy/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _setup_distributed( # type: ignore
init_dist(launcher, backend, **kwargs)

def convert_model(self, model: nn.Module) -> nn.Module:
"""convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
"""Convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``
(SyncBN) or ``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers.

Args:
Expand Down
31 changes: 16 additions & 15 deletions mmengine/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class Config:

def __init__(
self,
cfg_dict: dict = None,
cfg_dict: Optional[dict] = None,
cfg_text: Optional[str] = None,
filename: Optional[Union[str, Path]] = None,
env_variables: Optional[dict] = None,
Expand Down Expand Up @@ -1227,7 +1227,8 @@ def is_base_line(c):
if base_code is not None:
base_code = ast.Expression( # type: ignore
body=base_code.value) # type: ignore
base_files = eval(compile(base_code, '', mode='eval'))
base_files = eval(compile(base_code, '',
mode='eval')) # type: ignore
else:
base_files = []
elif file_format in ('.yml', '.yaml', '.json'):
Expand Down Expand Up @@ -1288,7 +1289,7 @@ def _get_cfg_path(cfg_path: str,
def _merge_a_into_b(a: dict,
b: dict,
allow_list_keys: bool = False) -> dict:
"""merge dict ``a`` into dict ``b`` (non-inplace).
"""Merge dict ``a`` into dict ``b`` (non-inplace).

Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid
in-place modifications.
Expand Down Expand Up @@ -1358,22 +1359,22 @@ def auto_argparser(description=None):

@property
def filename(self) -> str:
"""get file name of config."""
"""Get file name of config."""
return self._filename

@property
def text(self) -> str:
"""get config text."""
"""Get config text."""
return self._text

@property
def env_variables(self) -> dict:
"""get used environment variables."""
"""Get used environment variables."""
return self._env_variables

@property
def pretty_text(self) -> str:
"""get formatted python config text."""
"""Get formatted python config text."""

indent = 4

Expand Down Expand Up @@ -1727,17 +1728,17 @@ def to_dict(self, keep_imported: bool = False):


class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options can
be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit
brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build
list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""Argparse action to split an argument into KEY=VALUE form on the first =
and append to a dictionary.

List options can be passed as comma separated values, i.e 'KEY=V1,V2,V3',
or with explicit brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested
brackets to build list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]'
"""

@staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
"""parse int/float/bool value in the string."""
"""Parse int/float/bool value in the string."""
try:
return int(val)
except ValueError:
Expand Down Expand Up @@ -1822,7 +1823,7 @@ def __call__(self,
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
option_string: str = None):
option_string: str = None): # type: ignore
"""Parse Variables in string and add them into argparser.

Args:
Expand Down
2 changes: 1 addition & 1 deletion mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def cast_data_device(
Tensor or list or dict: ``data`` was casted to ``device``.
"""
if out is not None:
if type(data) != type(out):
if type(data) is not type(out):
raise TypeError(
'out should be the same type with data, but got data is '
f'{type(data)} and out is {type(data)}')
Expand Down
6 changes: 3 additions & 3 deletions mmengine/evaluator/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,11 @@ def __init__(self,
self.out_file_path = out_file_path

def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
"""Transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))

def compute_metrics(self, results: list) -> dict:
"""dump the prediction results to a pickle file."""
"""Dump the prediction results to a pickle file."""
dump(results, self.out_file_path)
print_log(
f'Results has been saved to {self.out_file_path}.',
Expand All @@ -188,7 +188,7 @@ def compute_metrics(self, results: list) -> dict:


def _to_cpu(data: Any) -> Any:
"""transfer all tensors and BaseDataElement to cpu."""
"""Transfer all tensors and BaseDataElement to cpu."""
if isinstance(data, (Tensor, BaseDataElement)):
return data.to('cpu')
elif isinstance(data, list):
Expand Down
2 changes: 1 addition & 1 deletion mmengine/hooks/profiler_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def after_train_epoch(self, runner):
self._export_chrome_trace(runner)

def after_train_iter(self, runner, batch_idx, data_batch, outputs):
"""profiler will call `step` method if it is not closed."""
"""Profiler will call `step` method if it is not closed."""
if not self._closed:
self.profiler.step()
if runner.iter == self.profile_times - 1 and not self.by_epoch:
Expand Down
2 changes: 1 addition & 1 deletion mmengine/logging/history_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _set_default_statistics(self) -> None:
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)

def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history.
"""Update the log history.

If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer.
Expand Down
14 changes: 7 additions & 7 deletions mmengine/model/base_model/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,17 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
dict or list: Data in the same format as the model input.
"""
data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
_batch_inputs = data['inputs'] # type: ignore
# Process data with `pseudo_collate`.
if is_seq_of(_batch_inputs, torch.Tensor):
batch_inputs = []
for _batch_input in _batch_inputs:
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
_batch_input = _batch_input[[2, 1, 0], ...] # type: ignore
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
_batch_input = _batch_input.float() # type: ignore
# Normalization.
if self._enable_normalize:
if self.mean.shape[0] == 3:
Expand Down Expand Up @@ -302,7 +302,7 @@ def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
else:
raise TypeError('Output of `cast_data` should be a dict of '
'list/tuple with inputs and data_samples, '
f'but got {type(data)}: {data}')
data['inputs'] = batch_inputs
data.setdefault('data_samples', None)
return data
f'but got {type(data)}: {data}') # type: ignore
data['inputs'] = batch_inputs # type: ignore
data.setdefault('data_samples', None) # type: ignore
return data # type: ignore
14 changes: 7 additions & 7 deletions mmengine/model/weight_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def caffe2_xavier_init(module, bias=0):


def bias_init_with_prob(prior_prob):
"""initialize conv/fc bias value according to a given probability value."""
"""Initialize conv/fc bias value according to a given probability value."""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init

Expand Down Expand Up @@ -662,12 +662,12 @@ def trunc_normal_(tensor: Tensor,
std: float = 1.,
a: float = -2.,
b: float = 2.) -> Tensor:
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
r"""Fills the input Tensor with values drawn from a truncated normal
distribution. The values are effectively drawn from the normal distribution
:math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
:math:`[a, b]` redrawn until they are within the bounds. The method used
for generating the random values works best when :math:`a \leq \text{mean}
\leq b`.

Modified from
https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
Expand Down
7 changes: 4 additions & 3 deletions mmengine/model/wrappers/fully_sharded_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __init__(
auto_wrap_policy: Union[str, Callable, None] = None,
backward_prefetch: Union[str, BackwardPrefetch, None] = None,
mixed_precision: Union[dict, MixedPrecision, None] = None,
param_init_fn: Union[str, Callable[[nn.Module], None]] = None,
param_init_fn: Union[str, Callable[
[nn.Module], None]] = None, # type: ignore # noqa: E501
use_orig_params: bool = True,
**kwargs,
):
Expand Down Expand Up @@ -362,7 +363,7 @@ def optim_state_dict(
optim: torch.optim.Optimizer,
group: Optional[dist.ProcessGroup] = None,
) -> Dict[str, Any]:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
state_dict_settings = FullyShardedDataParallel.get_state_dict_type(
model)
return FullyShardedDataParallel._optim_state_dict_impl(
Expand All @@ -384,7 +385,7 @@ def set_state_dict_type(
state_dict_config: Optional[StateDictConfig] = None,
optim_state_dict_config: Optional[OptimStateDictConfig] = None,
) -> StateDictSettings:
"""copied from pytorch 2.0.1 which has fixed some bugs."""
"""Copied from pytorch 2.0.1 which has fixed some bugs."""
import torch.distributed.fsdp._traversal_utils as traversal_utils
_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
Expand Down
3 changes: 1 addition & 2 deletions mmengine/optim/optimizer/apex_optimizer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None:
self._inner_count += 1

def state_dict(self) -> dict:
"""Get the state dictionary of :attr:`optimizer` and
:attr:`apex_amp`.
"""Get the state dictionary of :attr:`optimizer` and :attr:`apex_amp`.

Based on the state dictionary of the optimizer, the returned state
dictionary will add a key named "apex_amp".
Expand Down
4 changes: 2 additions & 2 deletions mmengine/optim/optimizer/default_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self,
self._validate_cfg()

def _validate_cfg(self) -> None:
"""verify the correctness of the config."""
"""Verify the correctness of the config."""
if not isinstance(self.paramwise_cfg, dict):
raise TypeError('paramwise_cfg should be None or a dict, '
f'but got {type(self.paramwise_cfg)}')
Expand All @@ -155,7 +155,7 @@ def _validate_cfg(self) -> None:
raise ValueError('base_wd should not be None')

def _is_in(self, param_group: dict, param_group_list: list) -> bool:
"""check whether the `param_group` is in the`param_group_list`"""
"""Check whether the `param_group` is in the`param_group_list`"""
assert is_list_of(param_group_list, dict)
param = set(param_group['params'])
param_set = set()
Expand Down
Loading
Loading