[WIP] distributed training #210

Open · wants to merge 4 commits into base: develop
26 changes: 18 additions & 8 deletions paddle3d/apis/config.py
@@ -78,10 +78,11 @@ def __init__(self,
else:
raise RuntimeError('Config file should in yaml format!')

self.update(learning_rate=learning_rate,
batch_size=batch_size,
iters=iters,
epochs=epochs)
self.update(
learning_rate=learning_rate,
batch_size=batch_size,
iters=iters,
epochs=epochs)

def _update_dic(self, dic: Dict, base_dic: Dict):
'''Update config from dic based base_dic
@@ -120,7 +121,8 @@ def update(self,
learning_rate: Optional[float] = None,
batch_size: Optional[int] = None,
iters: Optional[int] = None,
epochs: Optional[int] = None):
epochs: Optional[int] = None,
fleet: Optional[bool] = None):
'''Update config'''

if learning_rate is not None:
@@ -135,6 +137,9 @@ def update(self,
if epochs is not None:
self.dic['epochs'] = epochs

if fleet is not None:
self.dic['fleet'] = fleet

@property
def batch_size(self) -> int:
return self.dic.get('batch_size', 1)
@@ -181,6 +186,10 @@ def model(self) -> paddle.nn.Layer:
def amp_config(self) -> int:
return self.dic.get('amp_cfg', None)

@property
def fleet(self) -> bool:
return self.dic.get('fleet', False)

@property
def train_dataset_config(self) -> Dict:
return self.dic.get('train_dataset', {}).copy()
@@ -282,8 +291,8 @@ def _load_object(self, obj: Generic, recursive: bool = True) -> Any:
if recursive:
params = {}
for key, val in dic.items():
params[key] = self._load_object(obj=val,
recursive=recursive)
params[key] = self._load_object(
obj=val, recursive=recursive)
else:
params = dic
try:
@@ -317,7 +326,8 @@ def to_dict(self) -> Dict:
'train_dataset': self.train_dataset,
'val_dataset': self.val_dataset,
'batch_size': self.batch_size,
'amp_cfg': self.amp_config
'amp_cfg': self.amp_config,
'fleet': self.fleet
})

return dic
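
For context, a minimal usage sketch of the new option (hedged: the config path is illustrative, and the constructor call is assumed to take the yaml path the way tools/train.py passes it, which this hunk does not show):

    from paddle3d.apis.config import Config

    cfg = Config(path='configs/example.yml')   # hypothetical config file
    cfg.update(fleet=True)                     # new optional argument added in this PR
    assert cfg.fleet is True                   # the property defaults to False when the key is absent
    # cfg.to_dict() now also carries the 'fleet' key for downstream consumers.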
23 changes: 8 additions & 15 deletions paddle3d/apis/pipeline.py
@@ -36,9 +36,6 @@ def training_step(model: paddle.nn.Layer,
scaler=None,
amp_cfg=dict()) -> dict:

if optimizer.__class__.__name__ == 'OneCycleAdam':
optimizer.before_iter(cur_iter - 1)

model.train()

if isinstance(model, paddle.DataParallel) and hasattr(model._layers, 'use_recompute') \
@@ -67,20 +64,16 @@ def training_step(model: paddle.nn.Layer,
loss = parse_losses(outputs['loss'])
loss.backward()

if optimizer.__class__.__name__ == 'OneCycleAdam':
optimizer.after_iter()
if scaler is not None:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
else:
if scaler is not None:
scaler.step(optimizer)
scaler.update()
optimizer.clear_grad()
else:
optimizer.step()
optimizer.step()

model.clear_gradients()
if isinstance(optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()
model.clear_gradients()
if isinstance(optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
optimizer._learning_rate.step()

with paddle.no_grad():
if paddle.distributed.is_initialized():
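
For orientation, a hedged sketch of the update tail that training_step converges on after this change: the OneCycleAdam branches disappear because the wrapper now drives its own schedule inside step() (see the optimizers.py hunk further down), so AMP and FP32 training share one path. The function below is a placeholder that mirrors only that tail, not the full training_step.

    import paddle

    def apply_update(model, optimizer, scaler=None):
        # backward() has already run by this point in training_step
        if scaler is not None:                 # AMP path: GradScaler unscales and steps
            scaler.step(optimizer)
            scaler.update()
            optimizer.clear_grad()
        else:                                  # FP32 path: plain step, OneCycleAdam included
            optimizer.step()
        model.clear_gradients()
        if isinstance(optimizer._learning_rate, paddle.optimizer.lr.LRScheduler):
            optimizer._learning_rate.step()    # advance the per-iteration LR schedule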
22 changes: 16 additions & 6 deletions paddle3d/apis/trainer.py
@@ -19,6 +19,7 @@
from typing import Callable, Optional, Union

import paddle
from paddle.distributed import fleet
from visualdl import LogWriter

import paddle3d.env as env
@@ -121,7 +122,8 @@ def __init__(
checkpoint: Union[dict, CheckpointABC] = dict(),
scheduler: Union[dict, SchedulerABC] = dict(),
dataloader_fn: Union[dict, Callable] = dict(),
amp_cfg: Optional[dict] = None):
amp_cfg: Optional[dict] = None,
fleet: Optional[int] = False):

self.model = model
self.optimizer = optimizer
@@ -240,6 +242,7 @@ def set_lr_scheduler_iters_per_epoch(lr_scheduler,
logger.info(
'Use AMP train, AMP config: {}, Scaler config: {}'.format(
amp_cfg_, scaler_cfg_))
self.fleet = fleet

def train(self):
"""
@@ -260,11 +263,18 @@ def train(self):
self.model)

model = self.model
if env.nranks > 1:
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
):
paddle.distributed.init_parallel_env()
model = paddle.DataParallel(self.model)
if self.fleet:
strategy = fleet.DistributedStrategy()
strategy.find_unused_parameters = False
fleet.init(is_collective=True, strategy=strategy)
model = fleet.distributed_model(model)
self.optimizer = fleet.distributed_optimizer(self.optimizer)
else:
if env.nranks > 1:
if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
):
paddle.distributed.init_parallel_env()
model = paddle.DataParallel(self.model)

losses_sum = defaultdict(float)
timer = Timer(iters=self.iters - self.cur_iter)
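
A standalone sketch of the fleet branch added above; the Linear model and Adam optimizer are placeholders, only the fleet calls mirror the PR.

    import paddle
    from paddle.distributed import fleet

    strategy = fleet.DistributedStrategy()
    strategy.find_unused_parameters = False

    # is_collective=True selects collective (GPU) training and initializes the
    # parallel environment from the launcher's environment variables.
    fleet.init(is_collective=True, strategy=strategy)

    model = paddle.nn.Linear(16, 2)                                   # placeholder model
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())  # placeholder optimizer

    model = fleet.distributed_model(model)                # data-parallel wrapper
    optimizer = fleet.distributed_optimizer(optimizer)    # fleet-aware optimizer wrapper

Run under python -m paddle.distributed.launch, each process picks up its rank from the environment, which is why this branch skips the explicit init_parallel_env/DataParallel path used when fleet is off.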
38 changes: 21 additions & 17 deletions paddle3d/models/optimizers/lr_schedulers.py
@@ -20,6 +20,7 @@
Ths copyright of CaDDN is as follows:
Apache-2.0 license [see LICENSE for details].
"""

from functools import partial

import paddle
@@ -32,18 +33,21 @@

@manager.LR_SCHEDULERS.add_component
class OneCycleWarmupDecayLr(LRScheduler):

def __init__(self,
base_learning_rate,
lr_ratio_peak=10,
lr_ratio_trough=1e-4,
step_ratio_peak=0.4):
step_ratio_peak=0.4,
last_epoch=-1,
verbose=False):
self.base_learning_rate = base_learning_rate
self.lr_ratio_peak = lr_ratio_peak
self.lr_ratio_trough = lr_ratio_trough
self.step_ratio_peak = step_ratio_peak
self.lr_phases = [] # init lr_phases
self.anneal_func = annealing_cos
self.last_epoch = last_epoch
self.verbose = verbose

def before_run(self, max_iters):
"""before_run"""
@@ -54,8 +58,10 @@ def before_run(self, max_iters):
self.lr_ratio_trough
])

def get_lr(self, curr_iter):
def get_lr(self, curr_iter=None):
"""get_lr"""
if curr_iter is None:
curr_iter = self.last_epoch
for (start_iter, end_iter, lr_start_ratio,
lr_end_ratio) in self.lr_phases:
if start_iter <= curr_iter < end_iter:
@@ -66,7 +72,6 @@ def before_run(self, max_iters):


class LRSchedulerCycle(LRScheduler):

def __init__(self, total_step, lr_phases, mom_phases):

self.total_step = total_step
@@ -78,12 +83,12 @@ def __init__(self, total_step, lr_phases, mom_phases):
if isinstance(lambda_func, str):
lambda_func = eval(lambda_func)
if i < len(lr_phases) - 1:
self.lr_phases.append(
(int(start * total_step),
int(lr_phases[i + 1][0] * total_step), lambda_func))
self.lr_phases.append((int(start * total_step),
int(lr_phases[i + 1][0] * total_step),
lambda_func))
else:
self.lr_phases.append(
(int(start * total_step), total_step, lambda_func))
self.lr_phases.append((int(start * total_step), total_step,
lambda_func))
assert self.lr_phases[0][0] == 0
self.mom_phases = []
for i, (start, lambda_func) in enumerate(mom_phases):
@@ -92,19 +97,18 @@ def __init__(self, total_step, lr_phases, mom_phases):
if isinstance(lambda_func, str):
lambda_func = eval(lambda_func)
if i < len(mom_phases) - 1:
self.mom_phases.append(
(int(start * total_step),
int(mom_phases[i + 1][0] * total_step), lambda_func))
self.mom_phases.append((int(start * total_step),
int(mom_phases[i + 1][0] * total_step),
lambda_func))
else:
self.mom_phases.append(
(int(start * total_step), total_step, lambda_func))
self.mom_phases.append((int(start * total_step), total_step,
lambda_func))
assert self.mom_phases[0][0] == 0
super().__init__()


@manager.OPTIMIZERS.add_component
class OneCycle(LRSchedulerCycle):

def __init__(self, total_step, lr_max, moms, div_factor, pct_start):
self.lr_max = lr_max
self.moms = moms
@@ -154,8 +158,8 @@ def get_lr(self):
if self.last_epoch == 0:
return self.base_lr
else:
cur_epoch = (self.last_epoch +
self.warmup_iters) // self.iters_per_epoch
cur_epoch = (
self.last_epoch + self.warmup_iters) // self.iters_per_epoch
return annealing_cos(self.base_lr, self.eta_min,
cur_epoch / self.T_max)

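
A hedged usage sketch of the scheduler after this change (values are illustrative): with last_epoch and verbose stored and get_lr() falling back to last_epoch, the schedule can be queried either by passing the iteration explicitly, as OneCycleAdam.before_iter does, or through the standard LRScheduler-style protocol.

    sched = OneCycleWarmupDecayLr(base_learning_rate=1e-3,
                                  lr_ratio_peak=10,
                                  lr_ratio_trough=1e-4,
                                  step_ratio_peak=0.4)
    sched.before_run(max_iters=1000)   # builds the warmup/decay lr_phases

    lr_explicit = sched.get_lr(100)    # old calling convention: pass the iteration
    sched.last_epoch = 100
    lr_implicit = sched.get_lr()       # no argument: falls back to last_epoch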
10 changes: 9 additions & 1 deletion paddle3d/models/optimizers/optimizers.py
@@ -51,6 +51,10 @@ def __init__(self,
self._grad_clip = self.optimizer._grad_clip
self.optimizer._grad_clip = None

@property
def _parameter_list(self):
return self.optimizer._parameter_list

def _set_beta1(self, beta1, pow):
"""_set_beta1"""
# currently support Adam and AdamW only
@@ -70,7 +74,7 @@ def before_run(self, max_iters):

def before_iter(self, curr_iter):
"""before_iter"""
lr = self._learning_rate.get_lr(curr_iter=curr_iter)
lr = self._learning_rate.get_lr(curr_iter)
self.optimizer.set_lr(lr)
beta1 = self.beta1.get_momentum(curr_iter=curr_iter)
self._set_beta1(beta1, pow=curr_iter + 1)
@@ -119,6 +123,10 @@ def after_iter(self):
self.optimizer.step()
self.optimizer.clear_grad()

def step(self):
self.before_iter(self._learning_rate.last_epoch + 1)
self.after_iter()

def set_state_dict(self, optimizer):
self.optimizer.set_state_dict(optimizer)

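
A hedged sketch of what the new step() amounts to (opt stands for a configured OneCycleAdam; construction is omitted):

    def one_cycle_step(opt):
        # equivalent to the step() added above, spelled out
        cur_iter = opt._learning_rate.last_epoch + 1
        opt.before_iter(cur_iter)   # set the LR from the schedule and update beta1 momentum
        opt.after_iter()            # inner optimizer.step() followed by clear_grad()

Together with the pipeline.py change, this lets training_step call optimizer.step() uniformly; exposing _parameter_list is presumably what allows wrappers such as fleet.distributed_optimizer to treat OneCycleAdam like a native optimizer.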
9 changes: 8 additions & 1 deletion tools/train.py
@@ -118,6 +118,12 @@ def parse_args():
help='Config for quant model.',
default=None,
type=str)
parser.add_argument(
"--fleet",
dest='fleet',
help="Use fleet or not",
default=None,
type=bool)

return parser.parse_args()

@@ -154,7 +160,8 @@ def main(args):
learning_rate=args.learning_rate,
batch_size=args.batch_size,
iters=args.iters,
epochs=args.epochs)
epochs=args.epochs,
fleet=args.fleet)

if cfg.train_dataset is None:
raise RuntimeError(
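
For reference, a short sketch of how the new flag parses (plain argparse behavior, values illustrative): type=bool applies bool() to the raw string, so any non-empty value, including "False", enables fleet mode, while omitting the flag leaves it at None and the config default of False applies.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--fleet", dest='fleet', help="Use fleet or not",
                        default=None, type=bool)

    print(parser.parse_args(["--fleet", "True"]).fleet)   # True
    print(parser.parse_args(["--fleet", "False"]).fleet)  # also True: bool('False') is True
    print(parser.parse_args([]).fleet)                    # None -> Config keeps fleet=False

A typical launch would then look like python -m paddle.distributed.launch tools/train.py <existing training args> --fleet True (launcher options omitted here).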