Merge pull request #19 from OpenMOSS/opt-memory-speed
Resolves #10
Showing 11 changed files with 238 additions and 27 deletions.
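The two methods exercised by the new tests below are run_with_cache_until, which stops the forward pass once the requested activations have been cached, and offload_params_after, which frees parameters that are only needed after a given hook point. The following is a minimal usage sketch based solely on the call signatures that appear in those tests; the behaviour described in the comments is inferred from the test assertions, not from the implementation.

import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("solu-2l", device="cuda")

with torch.no_grad():
    # Cache the residual stream after block 0 and stop the forward pass there,
    # rather than running the remaining layers (inferred from the speed test below).
    _, cache = model.run_with_cache_until(
        "Hello, world!", names_filter=["blocks.0.hook_resid_post"]
    )

    # Free the parameters that are only used after the given hook point to cut
    # GPU memory; the second argument is an example input, as in the tests.
    model.offload_params_after(
        "blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda")
    )

print(cache["blocks.0.hook_resid_post"].shape)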
@@ -0,0 +1,35 @@
from transformer_lens import HookedTransformer
import torch

MODEL = "solu-2l"


def time_diff(func):
    """Return the CUDA wall-clock time of `func` in milliseconds."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()

    func()

    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event)


@torch.no_grad()
def test_offload_params_after():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")
    allocated_before = torch.cuda.memory_allocated(0)
    model.offload_params_after("blocks.0.hook_resid_post", torch.tensor([[0]], device="cuda"))
    allocated_after = torch.cuda.memory_allocated(0)
    # Offloading everything after blocks.0 should free roughly half of the model's parameters.
    assert allocated_after < allocated_before * 0.55


@torch.no_grad()
def test_run_with_cache_until():
    model = HookedTransformer.from_pretrained(MODEL, device="cuda")

    def forward():
        model.run_with_cache("Hello, world!", names_filter=["blocks.0.hook_resid_post"])

    forward_time = time_diff(forward)

    def forward_until():
        model.run_with_cache_until("Hello, world!", names_filter=["blocks.0.hook_resid_post"])

    forward_fake_time = time_diff(forward_until)

    # Stopping the forward pass at the cached hook should be noticeably faster
    # than running the full model.
    assert forward_fake_time < forward_time * 0.7
@@ -0,0 +1,52 @@
from transformer_lens.hook_points import HookedRootModule, HookPoint
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.subblock1 = nn.Linear(10, 10)
        self.subblock2 = nn.Linear(10, 10)
        self.activation = nn.ReLU()
        self.hook_mid = HookPoint()

    def forward(self, x):
        return self.subblock2(self.hook_mid(self.activation(self.subblock1(x))))


class TestModule(HookedRootModule):
    __test__ = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.blocks = nn.ModuleList([Block() for _ in range(3)])
        self.embed = nn.Linear(1, 10)
        self.unembed = nn.Linear(10, 1)
        self.setup()

    def forward(self, x):
        x = self.embed(x)
        for block in self.blocks:
            x = block(x)
        return self.unembed(x)


def test_run_with_cache_until():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    out, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])

    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])
    # run_with_cache_until returns the activation at the last requested hook
    # rather than the full model output.
    assert torch.allclose(cache_before["blocks.1.hook_mid"], out)


def test_offload_params_after():
    model = TestModule()
    _, cache_before = model.run_with_cache(torch.tensor([1.]))

    model.offload_params_after("blocks.1.hook_mid", torch.tensor([1.]))
    # Parameters needed up to and including blocks.1.hook_mid are kept ...
    assert model.blocks[0].subblock1.weight is not None
    assert model.blocks[1].subblock1.weight is not None
    # ... while everything after that hook is freed.
    assert model.blocks[2].subblock1.weight is None
    assert model.unembed.weight is None

    _, cache_after = model.run_with_cache_until(torch.tensor([1.]), names_filter=["blocks.0.hook_mid", "blocks.1.hook_mid"])
    assert torch.allclose(cache_before["blocks.0.hook_mid"], cache_after["blocks.0.hook_mid"])
    assert torch.allclose(cache_before["blocks.1.hook_mid"], cache_after["blocks.1.hook_mid"])