Feat (offload/fx): better buffer/params + call_functional
Giuseppe5 committed Jan 30, 2024
1 parent 883e193 commit b008e18
Showing 2 changed files with 25 additions and 24 deletions.
47 changes: 24 additions & 23 deletions src/brevitas_examples/optimum/offloading_utils.py
@@ -9,6 +9,7 @@
 from accelerate.utils import get_max_layer_size
 from accelerate.utils import get_max_memory
 from accelerate.utils import send_to_device
+from accelerate.utils.modeling import named_module_tensors
 import torch
 
 from brevitas.graph.utils import get_module
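
For context, a minimal sketch of what the newly imported helper yields, assuming accelerate's documented behaviour: named_module_tensors iterates over a module's parameters and buffers as (name, tensor) pairs, and with recurse=True the names are fully qualified, which is what lets the code below match get_attr targets against them. The toy module is hypothetical, not from the repository.

# Minimal sketch of named_module_tensors on a toy module.
import torch.nn as nn
from accelerate.utils.modeling import named_module_tensors

toy = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
for name, tensor in named_module_tensors(toy, recurse=True):
    # Parameters and buffers alike, e.g. '0.weight', '1.running_mean', ...
    print(name, tuple(tensor.shape))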
@@ -96,17 +97,29 @@ def infer_fx_auto_device_map(
     current_device = 0
     current_memory_used = 0
 
-    call_module_list = []
+    call_list = []
+    buffers_attributes = [n for n, _ in list(named_module_tensors(model, recurse=True))]
+    all_modules = [n.target for n in list(model.graph.nodes) if n.op == 'call_module']
     for node in model.graph.nodes:
         # If it's a module, we simply offload it or move it to the desired device
         if node.op == 'call_module':
             name = node.target
             module = get_module(model, node.target)
-            call_module_list.append((name, module))
+            call_list.append((name, module))
+        # If it's get_attr, we check what module it is attached to
+        # In case the module is not part of call_module, we specifically allocate the buffer/parameter on some device
+        # NB: This does NOT guarantee that it will be aligned with whatever input tensor it will be combined with
+        # For that, there is a separate function
+        if node.op == 'get_attr':
+            target = node.target
+            if target in buffers_attributes:
+                module_name = '.'.join(target.split('.')[:-1])
+                if module_name not in all_modules:
+                    module = get_module(model, target)
+                    call_list.append((target, module))
 
-    # Direct submodules and parameters
-    modules_to_treat = (
-        list(model.named_parameters(recurse=False)) + call_module_list +
-        list(model.named_buffers(recurse=False)))
+    modules_to_treat = call_list
     # Initialize maximum largest layer, to know which space to keep in memory
     max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, [])

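To illustrate the new get_attr branch above, here is a small self-contained sketch; the module and buffer names are hypothetical. A buffer read directly in forward shows up as a get_attr node whose owning module never appears as a call_module node, so the device map has to cover it explicitly.

# Toy illustration of why get_attr nodes need their own device assignment.
import torch
import torch.fx as fx


class WithBuffer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.register_buffer('scale', torch.ones(4))  # hypothetical buffer

    def forward(self, x):
        # Direct buffer access is traced as a get_attr node, not a call_module node.
        return x * self.scale


gm = fx.symbolic_trace(WithBuffer())
print([(n.op, str(n.target)) for n in gm.graph.nodes])
# Roughly: [('placeholder', 'x'), ('get_attr', 'scale'), ('call_function', '<built-in function mul>'), ('output', 'output')]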
@@ -249,14 +262,17 @@ def infer_fx_auto_device_map(
         current_memory_used += module_size
         device_map[name] = devices[current_device]
 
+    # If we have only one device, we simplify the device_map
+    if len(set(device_map.values())) == 1:
+        device_map = {'': list(device_map.values())[0]}
     return device_map
 
 
-def offload_call_function(model, max_memory, device_map):
-    max_memory = get_max_memory(max_memory)
-    devices = list(max_memory.keys())
+def offload_call_function(model, device_map):
+
+    # If we only have one device, offloading is not needed
+    if len(set(device_map.values())) == 1:
+        return
 
     for node in model.graph.nodes:
         if node.op == 'call_function':
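
As a small illustration of the single-device shortcuts added above, with hypothetical module names: accelerate accepts the shorthand {'': device} to mean "place the whole model on this device", and when every entry already resolves to one device there is nothing left for offload_call_function to do.

# Toy example of collapsing a uniform device_map to accelerate's {'': device} shorthand.
device_map = {'transformer.h.0': 0, 'transformer.h.1': 0, 'lm_head': 0}  # hypothetical names
if len(set(device_map.values())) == 1:
    device_map = {'': list(device_map.values())[0]}
print(device_map)  # {'': 0}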
@@ -285,18 +301,3 @@ def new_func(*args, old_callable=node.target, **kwargs):

     model.recompile()
     model.graph.lint()
-
-    # All keys that have been dealt with through `call_function` are on CPU by default, and moved when needed
-    all_model_tensors = [name for name, _ in model.state_dict().items()]
-    for module_name in device_map.keys():
-        if module_name == "":
-            all_model_tensors.clear()
-            break
-        else:
-            all_model_tensors = [
-                name for name in all_model_tensors
-                if not name == module_name and not name.startswith(module_name + ".")]
-    for tensor in all_model_tensors:
-        cpu_index = devices.index('cpu')
-        device_map[tensor] = devices[cpu_index]
-    return device_map
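
The truncated hunk above replaces each call_function target with a wrapper (new_func) that captures the original callable. What follows is a hedged sketch of that general pattern, not the repository's exact implementation; in particular, how the target device is chosen per node is an assumption here.

# Sketch: wrap a call_function target so its inputs are moved to a chosen
# device before the original callable runs.
from accelerate.utils import send_to_device


def wrap_on_device(old_callable, device):

    def new_func(*args, old_callable=old_callable, **kwargs):
        args = send_to_device(args, device)
        kwargs = send_to_device(kwargs, device)
        return old_callable(*args, **kwargs)

    return new_func


# Hypothetical usage on an fx.GraphModule `gm`:
# for node in gm.graph.nodes:
#     if node.op == 'call_function':
#         node.target = wrap_on_device(node.target, 'cuda:0')
# gm.recompile()
# gm.graph.lint()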
2 changes: 1 addition & 1 deletion src/brevitas_examples/optimum/utils.py
@@ -19,7 +19,7 @@ def maybe_offload_weights_to_cpu(model, is_fx=False):
     memory_map = {**cpu_device_map, **cuda_device_map}
     if is_fx:
         device_map = infer_fx_auto_device_map(model, memory_map)
-        device_map = offload_call_function(model, memory_map, device_map)
+        offload_call_function(model, device_map)
     else:
         device_map = infer_auto_device_map(
             model, memory_map, no_split_module_classes=model._no_split_modules)
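For orientation, a hedged sketch of how the two FX helpers might be composed by a caller; the offload_fx_model name and the use of accelerate's dispatch_model are assumptions for illustration, not code from the repository.

# Assumed composition: build a device_map for the FX GraphModule, patch its
# call_function nodes in place, then let accelerate place the weights.
from accelerate import dispatch_model

from brevitas_examples.optimum.offloading_utils import infer_fx_auto_device_map
from brevitas_examples.optimum.offloading_utils import offload_call_function


def offload_fx_model(model, memory_map):  # hypothetical helper
    device_map = infer_fx_auto_device_map(model, memory_map)
    offload_call_function(model, device_map)  # modifies the graph in place, returns None
    return dispatch_model(model, device_map=device_map)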
