def add_block(self, layer_accessor: str, ipu_id: int):
    """Wrap the layer named by `layer_accessor` in a `poptorch.BeginBlock`.

    The accessor is resolved starting from `self`, one dotted component at a
    time. The resolved layer is then replaced, in its parent container, by a
    `poptorch.BeginBlock` wrapping it.

    Args:
        layer_accessor: A string describing how to reach a layer in the model
            hierarchy, written like the Python expression that accesses it
            relative to the model: dot-separated attribute names, each
            optionally followed by one or more integer indices in square
            brackets. Valid examples are "up_blocks[0].resnets[2]" and
            "encoder.final_layer_norm".
        ipu_id: The ID of the IPU on which to place the layer (0 indexed).

    Returns:
        Self, so that calls can be chained.

    Raises:
        AttributeError: If an attribute named in the accessor does not exist.
        ValueError: If an index in the accessor is not an integer.
    """
    target = self
    parent = None
    attr_name = None
    last_index = None

    # Walk the accessor, remembering the direct parent of the final layer and
    # whether the last step was an attribute lookup or an index lookup.
    for component in layer_accessor.split("."):
        # "resnets[2][0]" -> attr_name "resnets", indices ["2", "0"]
        attr_name, *indices = component.replace("]", "").split("[")
        parent, target = target, getattr(target, attr_name)
        # A fresh attribute step invalidates any index from a previous component.
        last_index = None
        for raw_index in indices:
            last_index = int(raw_index)
            parent, target = target, target[last_index]

    # `str.split` always yields at least one component, so `parent` and
    # `attr_name` are guaranteed to be set once the loop has finished.
    block = poptorch.BeginBlock(target, layer_accessor, ipu_id=ipu_id)
    if last_index is not None:
        # The final step was an index: replace the item in its container.
        parent[last_index] = block
    else:
        # The final step was an attribute: rebind it on the parent object.
        setattr(parent, attr_name, block)
    return self