[AutoTuner] Add first version of autotuner (#124)
This PR adds the autotuner module, which can be enabled with one click by setting `action=auto_tune`, for example:
`python run.py --config-path ./examples/aquila/conf --config-name config action=auto_tune`.
AutoTuner currently supports searching over all major parallel strategies and related dimensions, including:

- data parallel
- tensor parallel
- pipeline parallel
- context parallel
- expert parallel
- recompute
- etc. 

AutoTuner is user-friendly: to customize it, users add an `auto_tuner` section on top of the training YAML, for example:
```
auto_tuner:
  space:
    num_layers_per_virtual_pipeline_stage: [1]
    use_recompute: [false]
  control:
    max_time_per_task: 300
    train_iters: 5
    max_time: 600
```
Currently we implement a heuristic grid search algorithm with built-in pruning strategies that use historical results to skip unpromising candidates efficiently. More search algorithms will be added in the future, so users do not need to configure this part for now; a simplified sketch of the idea follows.
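
As a rough sketch only (this simplified loop is illustrative, not the FlagScale implementation; `prune` and `run` stand in for the pruner and the task launcher):

```
import itertools

def grid_search(space, prune, run):
    """Enumerate the cartesian product of `space`, skipping pruned points."""
    history = []
    keys = list(space)
    for values in itertools.product(*space.values()):
        strategy = dict(zip(keys, values))
        if prune(strategy, history):  # decided from historical results
            continue                  # skip launching this candidate
        strategy["performance"] = run(strategy)
        history.append(strategy)
    return history
```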

Here `space` is the search space: users can customize the candidate values for each dimension, and any dimension left undefined falls back to a framework default. The following search dimensions are built in:

- data_parallel_size
- use_distributed_optimizer
- tensor_model_parallel_size
- sequence_parallel
- pipeline_model_parallel_size
- num_layers_per_virtual_pipeline_stage
- use_recompute
- recompute_method
- recompute_granularity
- recompute_num_layers
- micro_batch_size
- context_parallel_size
- expert_model_parallel_size

`control` governs the search process: the maximum running time of each task (`max_time_per_task`), how many training steps each task runs (`train_iters`), the maximum running time of the whole autotuner (`max_time`), and so on; a rough illustration follows.
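
For intuition only (this is not the actual runner code; `run_with_limits` and `tasks` are hypothetical names), the two time budgets could be enforced like this:

```
import subprocess
import time

def run_with_limits(tasks, max_time_per_task=300, max_time=600):
    """Launch each task with a per-task timeout and a global time budget."""
    start = time.time()
    for cmd in tasks:  # each cmd is an argv list for one training task
        if time.time() - start > max_time:
            break  # overall tuning budget exhausted, stop searching
        try:
            subprocess.run(cmd, timeout=max_time_per_task)
        except subprocess.TimeoutExpired:
            pass  # per-task limit hit; move on to the next strategy
```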

While the autotuner is running, each task gets its own log directory, and the results are summarized and sorted so that users only need to look at the CSV to see the detailed data for each task.
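
For example, the summary could be inspected as follows (the file path and column names are assumptions based on the fields recorded by the pruning code, not guaranteed outputs):

```
import pandas as pd

# Both the path and the column name below are illustrative.
df = pd.read_csv("./outputs/auto_tuner/results.csv")
print(df.sort_values("performance").head(10))  # top strategies first
```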

---------

Co-authored-by: caozhou <[email protected]>
Caozhou1995 and caozhou authored Jun 6, 2024
1 parent fc82ffb commit ac373cb
Showing 21 changed files with 1,921 additions and 70 deletions.
43 changes: 43 additions & 0 deletions examples/aquila/conf/config_auto_tuner.yaml
@@ -0,0 +1,43 @@
defaults:
  - train: demo
  - _self_

experiment:
  exp_name: aquila2
  exp_dir: ./outputs
  task:
    type: train
    backend: megatron
    entrypoint: ./flagscale/train/train_aquila.py
  runner:
    backend: torchrun
    nnodes: 1
    nproc_per_node: 8
  envs:
    CUDA_VISIBLE_DEVICES: 0,1,2,3,4,5,6,7
    CUDA_DEVICE_MAX_CONNECTIONS: 1
  auto_tuner:
    space:
      data_parallel_size: "auto"
      use_distributed_optimizer: [true, false]
      tensor_model_parallel_size: [2, 4, 8]
      sequence_parallel: [true]
      pipeline_model_parallel_size: "auto"
      num_layers_per_virtual_pipeline_stage: [1]
      context_parallel_size: "auto"
      expert_model_parallel_size: [1]
      micro_batch_size: "auto"
      use_recompute: [true]
      recompute_method: "auto"
      recompute_granularity: "auto"
      recompute_num_layers: "auto"
    control:
      max_time_per_task: 300
      train_iters: 5
      max_time: 600

action: run

hydra:
  run:
    dir: ${experiment.exp_dir}/hydra
2 changes: 1 addition & 1 deletion examples/aquila/conf/train/train_aquila_7b.yaml
@@ -61,4 +61,4 @@ data:
  vocab_file: ./examples/aquila/tokenizer/vocab.json
  merge_file: ./examples/aquila/tokenizer/merges.txt
  special_tokens_file: ./examples/aquila/tokenizer/special_tokens.txt
- vocab_size: 100008
+ vocab_size: 100008
1 change: 1 addition & 0 deletions flagscale/auto_tuner/__init__.py
@@ -0,0 +1 @@
from .tuner import AutoTuner
97 changes: 97 additions & 0 deletions flagscale/auto_tuner/generate.py
@@ -0,0 +1,97 @@
import os
import copy


class Generator:

    def __init__(self, config):
        self.config = config
        # TODO: Just a temporary solution, needs to be configured by the user
        if "args_mapping" in config.experiment.auto_tuner:
            self.args_mapping = config.experiment.auto_tuner.args_mapping
        else:
            self.args_mapping = {
                "data_parallel_size": "data_parallel_size",
                "use_distributed_optimizer": "use_distributed_optimizer",
                "tensor_model_parallel_size": "tensor_model_parallel_size",
                "sequence_parallel": "sequence_parallel",
                "pipeline_model_parallel_size": "pipeline_model_parallel_size",
                "num_layers_per_virtual_pipeline_stage":
                    "num_layers_per_virtual_pipeline_stage",
                "recompute_method": "recompute_method",
                "recompute_granularity": "recompute_granularity",
                "recompute_num_layers": "recompute_num_layers",
                "micro_batch_size": "micro_batch_size",
                "context_parallel_size": "context_parallel_size",
                "expert_model_parallel_size": "expert_model_parallel_size",
            }

    def _set_value(self, strategy, config):
        for key, value in self.args_mapping.items():
            if key in ["micro_batch_size"]:
                config.train.model[value] = strategy[key]
            elif key in ["data_parallel_size"]:
                continue
            else:
                if strategy[key] is None:
                    if value in config.train.system:
                        del config.train.system[value]
                    continue
                config.train.system[value] = strategy[key]

    def gen(self, strategy):
        config = copy.deepcopy(self.config)
        self._set_value(strategy, config)

        # Logging interval should be 1
        config.train.system.logging.log_interval = 1

        # Set redirect and tee
        config.experiment.runner.tee = 3
        config.experiment.runner.redirects = 3

        # auto_tune should be true: it skips checkpoint saving when training
        # ends and reports memory every iteration
        config.train.system.auto_tune = True

        # Del the sample-based scheduler args and train_samples to run megatron.
        assert "optimizer" in config.train.model
        assert "lr_scheduler" in config.train.model.optimizer
        if "lr_warmup_samples" in config.train.model.optimizer.lr_scheduler:
            del config.train.model.optimizer.lr_scheduler.lr_warmup_samples
        if "lr_decay_samples" in config.train.model.optimizer.lr_scheduler:
            del config.train.model.optimizer.lr_scheduler.lr_decay_samples
        if "rampup_batch_size" in config.train.model.optimizer.lr_scheduler:
            del config.train.model.optimizer.lr_scheduler.rampup_batch_size
        if "lr_warmup_fraction" in config.train.model.optimizer.lr_scheduler:
            del config.train.model.optimizer.lr_scheduler.lr_warmup_fraction

        if "train_samples" in config.train.model:
            del config.train.model.train_samples

        # Del checkpoint load
        if "checkpoint" in config.train.system:
            if "load" in config.train.system.checkpoint:
                del config.train.system.checkpoint.load
            if "save_interval" in config.train.system.checkpoint:
                config.train.system.checkpoint.save_interval = 2000

        # Set train_iters of each task
        if "control" in config.experiment.auto_tuner:
            config.train.model.train_iters = config.experiment.auto_tuner.control.get(
                "train_iters", 5)
        else:
            config.train.model.train_iters = 5

        # log dir
        config.experiment.exp_dir = os.path.join(config.experiment.exp_dir,
                                                 "auto_tuner",
                                                 f"task_{strategy['idx']}")

        return config

    def gen_best_task(self, strategy, config):
        self._set_value(strategy, config)
        return config
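
For reference, a hypothetical use of `Generator` (the strategy values below are illustrative; `idx` is the task index used for the per-task log directory, and `config` is the loaded training config):

```
generator = Generator(config)
strategy = {
    "idx": 0,
    "data_parallel_size": 2,
    "use_distributed_optimizer": True,
    "tensor_model_parallel_size": 4,
    "sequence_parallel": True,
    "pipeline_model_parallel_size": 1,
    "num_layers_per_virtual_pipeline_stage": None,
    "recompute_method": None,
    "recompute_granularity": None,
    "recompute_num_layers": None,
    "micro_batch_size": 1,
    "context_parallel_size": 1,
    "expert_model_parallel_size": 1,
}
task_config = generator.gen(strategy)  # deep-copied config for this task
```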
1 change: 1 addition & 0 deletions flagscale/auto_tuner/prune/__init__.py
@@ -0,0 +1 @@
from .pruner import Pruner
185 changes: 185 additions & 0 deletions flagscale/auto_tuner/prune/history.py
@@ -0,0 +1,185 @@
import logging
from ..utils import beside

_HISTORY_BASED_PRUNE_FUNC = []
logger = logging.getLogger("FlagScale-AutoTuner")


def register(func):

    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    _HISTORY_BASED_PRUNE_FUNC.append(wrapper)
    return wrapper


@register
def prune_by_micro_batch_size(config, strategy, history=[]):
    """Prune strategy by micro_batch_size, the rules are as follows:
    1. If the micro_batch_size of the current strategy is larger than that of a
       history strategy that ran OOM, prune it by memory.
    2. If the micro_batch_size of the current strategy is smaller than that of a
       history strategy with recorded performance, prune it by performance.
    """
    micro_batch_size = strategy["micro_batch_size"]
    retrieval = beside(["micro_batch_size", "acc_step"], strategy, history)
    if retrieval:
        for item in retrieval:
            # performance prune
            if item["micro_batch_size"] > micro_batch_size and item["performance"]:
                logger.info(
                    f"The strategy {strategy} has been pruned by micro_batch_size performance."
                )
                strategy["performance"] = item["performance"]
                strategy["max_mem"] = item["max_mem"]
                strategy["pruned"] = True
                return True
            # memory prune
            if item["micro_batch_size"] < micro_batch_size and item["max_mem"] == "OOM":
                logger.info(
                    f"The strategy {strategy} has been pruned by micro_batch_size memory."
                )
                strategy["max_mem"] = "OOM"
                strategy["performance"] = None
                strategy["pruned"] = True
                return True
    return False


@register
def prune_by_recompute(config, strategy, history=[]):
    """Prune strategy by recompute, the rules are as follows:
    1. If the current strategy uses recompute but a history strategy ran
       without recompute, prune it by performance.
    2. If the current strategy does not use recompute but a history strategy
       with recompute ran OOM, prune it by memory.
    3. If both use the 'uniform' method with 'full' granularity and the current
       strategy recomputes more layers than a history strategy that ran OOM,
       prune it by memory.
    4. If both use the 'uniform' method with 'full' granularity and the current
       strategy recomputes more layers than a history strategy with recorded
       performance, prune it by performance.
    5. If both use the 'block' method with 'full' granularity and the current
       strategy recomputes more layers than a history strategy with recorded
       performance, prune it by performance.
    6. If both use the 'block' method with 'full' granularity and the current
       strategy recomputes fewer layers than a history strategy that ran OOM,
       prune it by memory.
    """
    use_recompute = strategy["use_recompute"]
    recompute_method = strategy["recompute_method"]
    recompute_granularity = strategy["recompute_granularity"]
    recompute_num_layers = strategy["recompute_num_layers"]

    retrieval = beside(
        [
            "use_recompute",
            "recompute_method",
            "recompute_granularity",
            "recompute_num_layers",
        ],
        strategy,
        history,
    )
    for item in retrieval:
        # performance prune
        # If a history task ran without recompute, the task with recompute can be pruned
        if not item["use_recompute"] and use_recompute and item["performance"]:
            logger.info(
                f"The strategy {strategy} has been pruned by use_recompute performance."
            )
            strategy["performance"] = item["performance"]
            strategy["max_mem"] = item["max_mem"]
            strategy["pruned"] = True
            return True

        if (use_recompute and item["use_recompute"]
                and recompute_method == "block"
                and recompute_method == item["recompute_method"]
                and item["performance"]):
            if recompute_num_layers > item["recompute_num_layers"]:
                logger.info(
                    f"The strategy {strategy} has been pruned by block recompute_num_layers performance."
                )
                strategy["performance"] = item["performance"]
                strategy["max_mem"] = item["max_mem"]
                strategy["pruned"] = True
                return True

        if (use_recompute and item["use_recompute"]
                and recompute_method == "uniform"
                and recompute_method == item["recompute_method"]
                and item["performance"]):
            if recompute_num_layers > item["recompute_num_layers"]:
                logger.info(
                    f"The strategy {strategy} has been pruned by uniform recompute_num_layers performance."
                )
                strategy["performance"] = item["performance"]
                strategy["max_mem"] = item["max_mem"]
                strategy["pruned"] = True
                return True

        # memory prune
        if not use_recompute and item["use_recompute"] and item["max_mem"] == "OOM":
            logger.info(
                f"The strategy {strategy} has been pruned by use_recompute memory."
            )
            strategy["max_mem"] = "OOM"
            strategy["performance"] = None
            strategy["pruned"] = True
            return True

        if (use_recompute and item["use_recompute"]
                and recompute_method == "uniform"
                and recompute_method == item["recompute_method"]):
            if (recompute_num_layers > item["recompute_num_layers"]
                    and item["max_mem"] == "OOM"):
                logger.info(
                    f"The strategy {strategy} has been pruned by uniform recompute_num_layers memory."
                )
                strategy["max_mem"] = "OOM"
                strategy["performance"] = None
                strategy["pruned"] = True
                return True

        if (use_recompute and item["use_recompute"]
                and recompute_method == "block"
                and recompute_method == item["recompute_method"]):
            if (recompute_num_layers < item["recompute_num_layers"]
                    and item["max_mem"] == "OOM"):
                logger.info(
                    f"The strategy {strategy} has been pruned by block recompute_num_layers memory."
                )
                strategy["max_mem"] = "OOM"
                strategy["performance"] = None
                strategy["pruned"] = True
                return True
    return False


@register
def prune_by_sequence_parallel(config, strategy, history=[]):
    """Prune strategy by sequence_parallel."""
    sequence_parallel = strategy["sequence_parallel"]
    retrieval = beside(["sequence_parallel"], strategy, history)
    if retrieval:
        for item in retrieval:
            # performance prune
            if item["sequence_parallel"] and item["performance"] and not sequence_parallel:
                logger.info(
                    f"The strategy {strategy} has been pruned by sequence_parallel performance."
                )
                strategy["performance"] = item["performance"]
                strategy["max_mem"] = item["max_mem"]
                strategy["pruned"] = True
                return True
            # memory prune
            if item["sequence_parallel"] and item["max_mem"] == "OOM" and not sequence_parallel:
                logger.info(
                    f"The strategy {strategy} has been pruned by sequence_parallel memory."
                )
                strategy["max_mem"] = "OOM"
                strategy["performance"] = None
                strategy["pruned"] = True
                return True
    return False
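
The `beside` helper above is imported from `..utils` and is not part of this diff; a plausible reading, consistent with how the rules use it, is that it returns the history entries matching the current strategy on every dimension except the excluded keys:

```
# A guessed sketch of `beside` (the real implementation lives in
# flagscale/auto_tuner/utils.py and may differ): return history entries
# that agree with `strategy` on every search dimension except `keys`.
RESULT_KEYS = ("idx", "performance", "max_mem", "pruned")

def beside(keys, strategy, history):
    matched = []
    for item in history:
        if all(item[k] == strategy[k]
               for k in strategy
               if k in item and k not in keys and k not in RESULT_KEYS):
            matched.append(item)
    return matched
```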
20 changes: 20 additions & 0 deletions flagscale/auto_tuner/prune/pruner.py
@@ -0,0 +1,20 @@
from .history import _HISTORY_BASED_PRUNE_FUNC


class Pruner:

    def __init__(self, config):
        self.config = config
        self.pruned_count = 0

    def prune(self, strategy, history=[]):
        """Prune strategy based on history recorded strategies."""
        not_run = False
        for func in _HISTORY_BASED_PRUNE_FUNC:
            if func(self.config, strategy, history):
                not_run = True
                break
        history.append(strategy)
        if not_run:
            self.pruned_count += 1
        return not_run
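
Putting the pieces together, a hypothetical search loop would consult the pruner before launching each task (`candidate_strategies` and `launch` are illustrative names, not part of this diff):

```
pruner = Pruner(config)
history = []
for strategy in candidate_strategies:    # e.g. produced by the grid search
    if pruner.prune(strategy, history):  # also records strategy in history
        continue                         # pruned: skip launching this task
    launch(strategy)                     # hypothetical task launcher
print(f"{pruner.pruned_count} strategies pruned without running")
```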
1 change: 1 addition & 0 deletions flagscale/auto_tuner/record/__init__.py
@@ -0,0 +1 @@
from .recorder import Recorder