
Commit

Engines update (#1400)
* Engines update

* +

* +
Scitator authored Feb 7, 2022
1 parent 4ca33a0 commit 4e8e77f
Showing 38 changed files with 4,784 additions and 482 deletions.
1 change: 1 addition & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -40,5 +40,6 @@ If we didn't discuss your PR in Github issues there's a high chance it will not
- [ ] Have you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `latest` requirements? Please attach the notebook link.
- [ ] Have you run [colab minimal CI/CD](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/colab_ci_cd.ipynb) with `minimal` requirements? Please attach the notebook link.
- [ ] Have you checked [XLA integration](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/XLA.ipynb)? Please attach the notebook link.
- [ ] Have you checked [distributed XLA integration](https://colab.research.google.com/github/catalyst-team/catalyst/blob/master/examples/notebooks/XLA_ddp.ipynb)? Please attach the notebook link.

<!-- For CHANGELOG separate each item in unreleased section by blank line to reduce collisions -->
2 changes: 1 addition & 1 deletion catalyst/core/__init__.py
@@ -5,7 +5,7 @@
# callback
# logger

from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine
from catalyst.core.runner import IRunner, IRunnerError
from catalyst.core.callback import (
ICallback,
2 changes: 1 addition & 1 deletion catalyst/core/callback.py
@@ -123,7 +123,7 @@ class Callback(ICallback):
To learn more about Catalyst Core concepts, please check out
- :py:mod:`catalyst.core.runner.IRunner`
- :py:mod:`catalyst.core.engine.IEngine`
- :py:mod:`catalyst.core.engine.Engine`
- :py:mod:`catalyst.core.callback.Callback`
"""
8 changes: 4 additions & 4 deletions catalyst/core/engine.py
@@ -15,7 +15,7 @@
import torch_xla.core.xla_model as xm


class IEngine(Accelerator):
class Engine(Accelerator):
"""
An abstraction that syncs experiment run with
different hardware-specific configurations.
@@ -26,7 +26,7 @@ class IEngine(Accelerator):
- DDP (deepspeed, torch)
- XLA
Abstraction, please check out implementations for more details:
Please check out implementations for more details:
- :py:mod:`catalyst.engines.torch.CPUEngine`
- :py:mod:`catalyst.engines.torch.GPUEngine`
- :py:mod:`catalyst.engines.torch.DataParallelEngine`
@@ -72,7 +72,7 @@ def cleanup(self):
"""Cleans DDP variables and processes."""
pass

def mean_reduce_ddp_metrics(self, metrics: Dict):
def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
"""Syncs ``metrics`` over ``world_size`` in the distributed mode."""
if self.state.distributed_type in [
DistributedType.MULTI_CPU,
@@ -95,4 +95,4 @@ def mean_reduce_ddp_metrics(self, metrics: Dict):
return metrics


__all__ = ["IEngine"]
__all__ = ["Engine"]
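
The heart of this commit is visible in this file: `IEngine` is renamed to `Engine`, and the class sits directly on the `Accelerator` base shown in the class header above. A minimal import sketch of what downstream code looks like after the rename, assuming only the paths shown in this commit:

    # sketch: the abstraction is now imported as Engine, not IEngine
    from catalyst.core.engine import Engine
    from catalyst.engines.torch import CPUEngine

    engine = CPUEngine()  # per the engines/torch.py diff below, this forwards cpu=True to Accelerator
    assert isinstance(engine, Engine)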
12 changes: 6 additions & 6 deletions catalyst/core/runner.py
@@ -6,7 +6,7 @@
from torch.utils.data import DataLoader, DistributedSampler

from catalyst.core.callback import Callback, ICallback
from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine
from catalyst.core.logger import ILogger
from catalyst.core.misc import (
check_callbacks,
@@ -45,7 +45,7 @@ class IRunner(ICallback, ILogger, ABC):
Args:
model: Torch model object
engine: IEngine instance
engine: Engine instance
Abstraction, please check out implementations for more details:
@@ -56,7 +56,7 @@ class IRunner(ICallback, ILogger, ABC):
To learn more about Catalyst Core concepts, please check out
- :py:mod:`catalyst.core.runner.IRunner`
- :py:mod:`catalyst.core.engine.IEngine`
- :py:mod:`catalyst.core.engine.Engine`
- :py:mod:`catalyst.core.callback.Callback`
.. note::
@@ -66,9 +66,9 @@ class IRunner(ICallback, ILogger, ABC):
"""

def __init__(self, model: RunnerModel = None, engine: IEngine = None):
def __init__(self, model: RunnerModel = None, engine: Engine = None):
"""Init."""
self.engine: IEngine = engine
self.engine: Engine = engine
self.loggers: Dict[str, ILogger] = {}
self.loaders: Dict[str, DataLoader] = None
self.model: RunnerModel = model
@@ -169,7 +169,7 @@ def close_log(self) -> None:
logger.close_log()

@abstractmethod
def get_engine(self) -> IEngine:
def get_engine(self) -> Engine:
"""Returns the engine for the experiment."""
pass

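
The abstract `get_engine` hook is now annotated to return `Engine`. A minimal sketch of overriding it in a custom runner; `MyRunner` is a hypothetical name, and the other abstract `IRunner` members are omitted, so this class will not instantiate until they are defined:

    from catalyst.core.engine import Engine
    from catalyst.core.runner import IRunner
    from catalyst.engines.torch import CPUEngine

    class MyRunner(IRunner):  # hypothetical; remaining abstract methods omitted
        def get_engine(self) -> Engine:
            # choose the hardware backend for this experiment
            return CPUEngine()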
7 changes: 3 additions & 4 deletions catalyst/engines/__init__.py
@@ -1,20 +1,19 @@
# flake8: noqa

from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine

from catalyst.engines.torch import (
CPUEngine,
GPUEngine,
DeviceEngine,
Engine,
DataParallelEngine,
DistributedDataParallelEngine,
)

__all__ = [
"IEngine",
"Engine",
"CPUEngine",
"GPUEngine",
"DeviceEngine",
"DataParallelEngine",
"DistributedDataParallelEngine",
]
64 changes: 43 additions & 21 deletions catalyst/engines/torch.py
@@ -1,5 +1,5 @@
# taken from https://github.com/Scitator/animus/blob/main/animus/torch/accelerate.py
from typing import Callable, Dict
from typing import Any, Callable, Dict, Optional, Union
import os

import numpy as np
@@ -9,39 +9,31 @@
import torch.multiprocessing as mp

from catalyst import SETTINGS
from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine
from catalyst.utils.distributed import mean_reduce

if SETTINGS.xla_required:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


class DeviceEngine(IEngine):
"""Singe-device engine."""

def __init__(self, *args, **kwargs) -> None:
"""Init."""
super().__init__(*args, **kwargs)


class CPUEngine(IEngine):
class CPUEngine(Engine):
"""CPU-based engine."""

def __init__(self, *args, **kwargs) -> None:
"""Init."""
super().__init__(*args, cpu=True, **kwargs)


class GPUEngine(IEngine):
class GPUEngine(Engine):
"""Single-GPU-based engine."""

def __init__(self, *args, **kwargs) -> None:
"""Init."""
super().__init__(*args, cpu=False, **kwargs)


class DataParallelEngine(IEngine):
class DataParallelEngine(Engine):
"""Multi-GPU-based engine."""

def __init__(self, *args, **kwargs) -> None:
@@ -55,11 +47,38 @@ def prepare_model(self, model):
return model


class DistributedDataParallelEngine(IEngine):
"""Distributed multi-GPU-based engine."""

def __init__(self, *args, **kwargs):
class DistributedDataParallelEngine(Engine):
"""Distributed multi-GPU-based engine.
Args:
*args: args for Accelerator.__init__
address: master node (rank 0)'s address, should be either the IP address or the hostname
of node 0, for single node multi-proc training, can simply be 127.0.0.1
port: master node (rank 0)'s free port that needs to be used for communication
during distributed training
world_size: the number of processes to use for distributed training.
Should be less or equal to the number of GPUs
process_group_kwargs: parameters for `torch.distributed.init_process_group`.
More info here:
https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group # noqa: E501, W505
**kwargs: kwargs for Accelerator.__init__
"""

def __init__(
self,
*args,
address: str = "127.0.0.1",
port: Union[str, int] = 2112,
world_size: Optional[int] = None,
process_group_kwargs: Dict[str, Any] = None,
**kwargs
):
"""Init."""
self._address = os.environ.get("MASTER_ADDR", address)
self._port = os.environ.get("MASTER_PORT", port)
self._world_size = world_size
self._process_group_kwargs = process_group_kwargs or {}
self._args = args
self._kwargs = kwargs
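
A usage sketch for the extended constructor above, with the documented defaults spelled out; `world_size=2` is an assumption about available GPUs (when left as `None`, `spawn` falls back to `torch.cuda.device_count()`, as the next hunk shows):

    from catalyst.engines.torch import DistributedDataParallelEngine

    # sketch: the new knobs of the extended signature
    engine = DistributedDataParallelEngine(
        address="127.0.0.1",   # master node (rank 0) address
        port=2112,             # free port on the master node
        world_size=2,          # assumption: two GPUs; None -> torch.cuda.device_count()
        process_group_kwargs={"backend": "nccl"},  # forwarded to init_process_group
    )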

Expand All @@ -80,7 +99,7 @@ def spawn(self, fn: Callable, *args, **kwargs):
Returns:
wrapped function (if needed).
"""
world_size: int = torch.cuda.device_count()
world_size: int = self._world_size or torch.cuda.device_count()
return mp.spawn(
fn,
args=(world_size,),
@@ -99,7 +118,10 @@ def setup(self, local_rank: int, world_size: int):
process_group_kwargs = {
"backend": "nccl",
"world_size": world_size,
**self._process_group_kwargs,
}
os.environ["MASTER_ADDR"] = str(self._address)
os.environ["MASTER_PORT"] = str(self._port)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["RANK"] = str(local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
@@ -110,7 +132,7 @@ def cleanup(self):
"""Cleans DDP variables and processes."""
dist.destroy_process_group()

def mean_reduce_ddp_metrics(self, metrics: Dict):
def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
"""Syncs ``metrics`` over ``world_size`` in the distributed mode."""
metrics = {
k: mean_reduce(
@@ -122,7 +144,7 @@ def mean_reduce_ddp_metrics(self, metrics: Dict):
return metrics


class DistributedXLAEngine(IEngine):
class DistributedXLAEngine(Engine):
"""Distributed XLA-based engine."""

def __init__(self, *args, **kwargs):
@@ -160,7 +182,7 @@ def setup(self, local_rank: int, world_size: int):
"""
super().__init__(self, *self._args, **self._kwargs)

def mean_reduce_ddp_metrics(self, metrics: Dict):
def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
"""Syncs ``metrics`` over ``world_size`` in the distributed mode."""
metrics = {
k: xm.mesh_reduce(k, v.item() if isinstance(v, torch.Tensor) else v, np.mean)
10 changes: 5 additions & 5 deletions catalyst/runners/runner.py
@@ -10,7 +10,7 @@
from catalyst.callbacks.misc import CheckRunCallback, TimerCallback, TqdmCallback
from catalyst.callbacks.profiler import ProfilerCallback
from catalyst.core.callback import Callback
from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine
from catalyst.core.logger import ILogger
from catalyst.core.misc import callback_isinstance, sort_callbacks_by_order
from catalyst.core.runner import IRunner, IRunnerError
@@ -182,7 +182,7 @@ def num_epochs(self) -> int:
"""Returns the number of epochs in the experiment."""
return self._num_epochs

def get_engine(self) -> IEngine:
def get_engine(self) -> Engine:
"""Returns the engine for the experiment."""
return self._engine

@@ -271,7 +271,7 @@ def train(
loaders: "OrderedDict[str, DataLoader]",
# the core
model: TorchModel,
engine: Union["IEngine", str] = None,
engine: Union["Engine", str] = None,
# the components
criterion: TorchCriterion = None,
optimizer: TorchOptimizer = None,
@@ -405,7 +405,7 @@ def predict_loader(
*,
loader: DataLoader,
model: TorchModel = None,
engine: Union["IEngine", str] = None,
engine: Union["Engine", str] = None,
seed: int = 42,
# extra info
resume: str = None,
@@ -458,7 +458,7 @@ def evaluate_loader(
loader: DataLoader,
callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
model: Optional[TorchModel] = None,
engine: Union["IEngine", str] = None,
engine: Union["Engine", str] = None,
seed: int = 42,
verbose: bool = False,
) -> Dict[str, Any]:
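
All three entry points (`train`, `predict_loader`, `evaluate_loader`) now accept `engine` as either an `Engine` instance or a string; the alias resolution lives outside this diff, so the sketch below passes an explicit instance. `runner`, `loaders`, `model`, `criterion`, and `optimizer` are placeholders:

    from catalyst.engines import GPUEngine

    # call sketch for the updated train() signature
    runner.train(
        loaders=loaders,      # OrderedDict[str, DataLoader]
        model=model,          # torch.nn.Module
        engine=GPUEngine(),   # previously typed Union["IEngine", str]
        criterion=criterion,
        optimizer=optimizer,
    )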
6 changes: 3 additions & 3 deletions catalyst/runners/supervised.py
@@ -14,7 +14,7 @@
IOptimizerCallback,
ISchedulerCallback,
)
from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine
from catalyst.core.misc import callback_isinstance, sort_callbacks_by_order
from catalyst.core.runner import IRunner
from catalyst.runners.runner import Runner
@@ -167,7 +167,7 @@ class SupervisedRunner(ISupervisedRunner, Runner):
Args:
model: Torch model instance
engine: IEngine instance
engine: Engine instance
input_key: key in ``runner.batch`` dict mapping for model input
output_key: key for ``runner.batch`` to store model output
target_key: key in ``runner.batch`` dict mapping for target
@@ -192,7 +192,7 @@ class SupervisedRunner(ISupervisedRunner, Runner):
def __init__(
self,
model: RunnerModel = None,
engine: IEngine = None,
engine: Engine = None,
input_key: Any = "features",
output_key: Any = "logits",
target_key: str = "targets",
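
A construction sketch for the updated annotation; the key names repeat the defaults from the signature above, and `nn.Linear(4, 2)` stands in as a placeholder model:

    from torch import nn

    from catalyst.engines import CPUEngine
    from catalyst.runners.supervised import SupervisedRunner

    runner = SupervisedRunner(
        model=nn.Linear(4, 2),  # placeholder model
        engine=CPUEngine(),     # annotation changed from IEngine to Engine
        input_key="features",
        output_key="logits",
        target_key="targets",
    )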
6 changes: 3 additions & 3 deletions catalyst/utils/torch.py
@@ -24,7 +24,7 @@
import torch_xla.core.xla_model as xm

if TYPE_CHECKING:
from catalyst.core.engine import IEngine
from catalyst.core.engine import Engine


def get_optimizer_momentum(optimizer: TorchOptimizer) -> float:
@@ -92,7 +92,7 @@ def get_available_engine(
cpu: bool = False,
fp16: bool = False,
ddp: bool = False,
) -> "IEngine":
) -> "Engine":
"""Returns available engine based on given arguments.
Args:
@@ -101,7 +101,7 @@
fp16 (bool): option to use APEX for training. Default is `False`.
Returns:
IEngine which match requirements.
Engine which match requirements.
"""
from catalyst.engines.torch import (
CPUEngine,
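
A usage sketch for the renamed return type; which concrete `Engine` comes back depends on the `cpu`/`fp16`/`ddp` flags and on hardware availability, and that selection logic sits outside this hunk:

    from catalyst.utils.torch import get_available_engine

    # sketch: request a plain single-device engine
    engine = get_available_engine(cpu=False, fp16=False, ddp=False)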
2 changes: 1 addition & 1 deletion docs/api/core.rst
@@ -28,7 +28,7 @@ Runner

Engine
----------------------
.. autoclass:: catalyst.core.engine.IEngine
.. autoclass:: catalyst.core.engine.Engine
:members:
:exclude-members: __init__
:undoc-members:
7 changes: 0 additions & 7 deletions docs/api/engines.rst
@@ -28,13 +28,6 @@ GPUEngine
:undoc-members:
:show-inheritance:

DeviceEngine
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.engines.torch.DeviceEngine
:exclude-members: __init__
:undoc-members:
:show-inheritance:

DataParallelEngine
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: catalyst.engines.torch.DataParallelEngine