Uses torch.fx to parallelize and transform pipelined models #193

Open
wants to merge 33 commits into base: main
Commits
9118397
[WIP] fx formalism for BERT
michaelbenayoun Jun 29, 2022
b1e23f3
Transformations are working
michaelbenayoun Jul 7, 2022
87e9273
[WIP] fx formalism for BERT
michaelbenayoun Jul 8, 2022
7009e6a
Temp
michaelbenayoun Jul 21, 2022
d7df7b5
Trainer with symbolically traced models
michaelbenayoun Aug 3, 2022
29e48ec
Fix issues
michaelbenayoun Aug 3, 2022
1ac1e9f
FX parallelize for T5
michaelbenayoun Aug 5, 2022
5cbb50e
[WIP] BART
michaelbenayoun Aug 10, 2022
b54758b
[WIP] Roberta
michaelbenayoun Aug 11, 2022
3207fec
Fix BART after rebasing
michaelbenayoun Aug 22, 2022
2f4b968
Small fixes
michaelbenayoun Aug 22, 2022
28ff013
[WIP] working or almost working version for everyone
michaelbenayoun Aug 26, 2022
22d65c7
Removed unused code in T5
michaelbenayoun Aug 30, 2022
41b7814
Deberta recomputation checkpoint
michaelbenayoun Aug 31, 2022
c2eed06
Fix BartForSequenceClassification
michaelbenayoun Sep 20, 2022
2a5c716
Fixes
michaelbenayoun Oct 11, 2022
aa4e93e
Fixes
michaelbenayoun Oct 11, 2022
7fd6bc9
Fixes
michaelbenayoun Oct 11, 2022
b896862
Fix issues
michaelbenayoun Oct 13, 2022
ddfd342
[WIP] tests/test_pipelined_models.py
michaelbenayoun Oct 18, 2022
0b23428
All tests but test_examples are passing
michaelbenayoun Oct 19, 2022
a7abe14
Add TransformationManager
michaelbenayoun Oct 20, 2022
0fe6da7
Make style
michaelbenayoun Oct 20, 2022
ba03938
Change deberta
michaelbenayoun Oct 20, 2022
89f04c8
Add missing files
michaelbenayoun Oct 21, 2022
523f2fc
Format
michaelbenayoun Oct 21, 2022
84d55bb
some stuff
michaelbenayoun Oct 21, 2022
8b2e8fb
Fix BART
michaelbenayoun Oct 28, 2022
c94d930
changes
michaelbenayoun Nov 7, 2022
a582a46
GPT-2 Fix
michaelbenayoun Nov 8, 2022
43f2001
commit for diff
michaelbenayoun Nov 10, 2022
08dc5d1
Fix transformation => Deberta compiles
michaelbenayoun Nov 14, 2022
586fde2
Fixed BART
michaelbenayoun Nov 14, 2022
2 changes: 1 addition & 1 deletion examples/language-modeling/ipu_config.json
@@ -1,5 +1,5 @@
{
-    "embedding_serialization_factor": 1,
+    "embedding_serialization_factor": 2,
    "recompute_checkpoint_every_layer": true,
    "optimizer_state_offchip": true,
    "replicated_tensor_sharding": true,
31 changes: 31 additions & 0 deletions optimum/graphcore/fx/__init__.py
@@ -0,0 +1,31 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .transformation_manager import DEFAULT_TRANSFORMATION_MANAGER, TransformationManager # noqa
from .transformations import ( # noqa
    AddPoptorchBlock,
    AddPoptorchBlocksInSeries,
    AutoCast,
    ClipValues,
    ClipValuesSymmetric,
    LinearToSerializedLinear,
    OutlineAttribute,
    RecomputationCheckpoint,
    ShareEmbeddingComputation,
    TieWeights,
    TupleOutput,
    VocabEmbeddingToSerializedEmbedding,
)
from .utils import symbolic_trace_pipelined_model # noqa
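This module only re-exports the new FX tooling so downstream code can import everything from one place. A minimal import sketch, assuming nothing beyond the names re-exported above:

from optimum.graphcore.fx import (
    DEFAULT_TRANSFORMATION_MANAGER,
    RecomputationCheckpoint,
    TransformationManager,
    symbolic_trace_pipelined_model,
)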
129 changes: 129 additions & 0 deletions optimum/graphcore/fx/transformation_manager.py
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the class that manages which transformations to apply according to some optimization level."""

import copy
import functools
import operator
from typing import Iterator, List, Tuple, Union

import torch

from ...fx.optimization import (
    ChangeTrueDivToMulByInverse,
    MergeLinears,
    ReversibleTransformation,
    Transformation,
    compose,
)
from .transformations import ClipValues, ClipValuesSymmetric


class TransformationManager:
    def __init__(self, *transformations: Tuple[int, "Transformation"]):
        self._signatures = {
            0: set(),
            1: set(),
            2: set(),
        }
        self._transformations = {
            0: [],
            1: [],
            2: [],
        }
        self.register(*transformations)

    def without(self, *args: Transformation) -> "TransformationManager":
        clone = copy.deepcopy(self)
        clone.unregister(*args)
        return clone

    def register(self, *transformations: Tuple[int, Transformation]):
        for (opt_level, t) in transformations:
            for k, signatures in self._signatures.items():
                if t.signature in signatures:
                    raise RuntimeError(
                        f"The transformation {t} has already been registered at optimization level = {k}."
                    )
            self._signatures[opt_level].add(t.signature)
            self._transformations[opt_level].append(t)

    def unregister(self, *transformations: Transformation):
        for transformation_to_unregister in transformations:
            level = None
            sig = transformation_to_unregister.signature
            for opt_level, signatures in self._signatures.items():
                if sig in signatures:
                    level = opt_level
                    signatures.remove(sig)
            if level is not None:
                idx_to_pop = None
                for idx, t in enumerate(self._transformations[level]):
                    if t.signature == sig:
                        idx_to_pop = idx
                        break
                self._transformations[level].pop(idx_to_pop)

    def _check_optimization_level(self, optimization_level):
        if optimization_level not in [0, 1, 2]:
            raise ValueError(f"The optimization level must be either 0, 1 or 2, but {optimization_level} was given.")

    def _get_transformations(
        self, optimization_level: int, as_list: bool = False
    ) -> Union[Iterator[Transformation], List[Transformation]]:
        self._check_optimization_level(optimization_level)
        iterator = functools.reduce(
            lambda x, y: x + y, (self._transformations[i] for i in range(optimization_level + 1)), []
        )
        return iterator if as_list is False else list(iterator)

    def get_transformations(self, optimization_level: int) -> List[Transformation]:
        return self._get_transformations(optimization_level, as_list=True)

    def get_non_reversible_transformations(self, optimization_level: int) -> List[Transformation]:
        return [
            t for t in self._get_transformations(optimization_level) if not isinstance(t, ReversibleTransformation)
        ]

    def get_reversible_transformations(self, optimization_level: int) -> List[ReversibleTransformation]:
        return [t for t in self._get_transformations(optimization_level) if isinstance(t, ReversibleTransformation)]

    def _compose_transformations(
        self, optimization_level: int, transformations: List[Transformation]
    ) -> Transformation:
        return compose(*transformations) if transformations else lambda x: x

    def compose_transformations(self, optimization_level: int) -> Transformation:
        return self._compose_transformations(optimization_level, self.get_transformations(optimization_level))

    def compose_non_reversible_transformations(self, optimization_level: int) -> Transformation:
        return self._compose_transformations(
            optimization_level, self.get_non_reversible_transformations(optimization_level)
        )

    def compose_reversible_transformations(self, optimization_level: int) -> ReversibleTransformation:
        return self._compose_transformations(
            optimization_level, self.get_reversible_transformations(optimization_level)
        )


DEFAULT_TRANSFORMATION_MANAGER = TransformationManager(
    (1, ChangeTrueDivToMulByInverse()),
    (1, MergeLinears()),
    # (1, FuseBiasInLinear()),
    # Those change the computation, but are actually needed for fp16 stability.
    (0, ClipValuesSymmetric(1e4, include_targets=(torch.add, torch.mul, operator.add, operator.mul))),
    (0, ClipValues(1e-4, float("inf"), include_targets=(torch.nn.LayerNorm,))),
)
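For reviewers, a minimal usage sketch of the manager defined above. It relies only on APIs visible in this diff plus optimum.fx.optimization; the traced_model variable stands for a torch.fx.GraphModule (e.g. from symbolic_trace_pipelined_model) and is illustrative, not part of this PR.

from optimum.fx.optimization import MergeLinears
from optimum.graphcore.fx import DEFAULT_TRANSFORMATION_MANAGER

# Work on a copy of the default manager that skips MergeLinears
# (TransformationManager.without deep-copies, so the default stays untouched).
manager = DEFAULT_TRANSFORMATION_MANAGER.without(MergeLinears())

# Everything registered at levels 0 and 1, in registration order.
transformations = manager.get_transformations(optimization_level=1)

# Compose them into a single callable and apply it to a traced model.
composed = manager.compose_transformations(optimization_level=1)
# transformed_model = composed(traced_model)  # traced_model: a torch.fx.GraphModule (illustrative)

# Reversible transformations can be composed separately, e.g. to undo graph
# changes before saving or exporting the model.
reversible = manager.compose_reversible_transformations(optimization_level=1)

Splitting reversible from non-reversible transformations mirrors the ReversibleTransformation distinction already present in optimum.fx.optimization, which the manager imports above.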