Skip to content

Commit

Permalink
prepare lora for state dict change (#791)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Nov 6, 2024
1 parent 0e6a6b4 commit 1c43256
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __call__(self, x):
@staticmethod
def init(*, key):
k1, k2 = jax.random.split(key)
first = hnn.Linear.init(In, Mid, key=k1)
second = hnn.Linear.init(Mid, In, key=k2)
first = hnn.Linear.init(In, Mid, key=k1, out_first=True)
second = hnn.Linear.init(Mid, In, key=k2, out_first=True)
return Module(first, second)

Layers = hax.Axis("Layers", 3)
Expand All @@ -91,7 +91,7 @@ def init(*, key):
assert loraized.stacked.first.lora.lora_A.weight.axes == (Layers, hax.Axis("LORA_R", 8), In)
assert loraized.stacked.first.lora.lora_B.weight.axes == (Layers, Mid, hax.Axis("LORA_R", 8))

assert loraized.stacked.second.weight.axes == (Layers, Mid, In)
assert loraized.stacked.second.weight.axes == (Layers, In, Mid)
input = hax.random.normal(k0, (In,))
assert not hax.all(hax.isclose(module.fold(input), loraized.fold(input)))

Expand Down

0 comments on commit 1c43256

Please sign in to comment.