Tweak the jax example
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Jun 3, 2024
1 parent d5b400c commit 1a71d48
Showing 2 changed files with 20 additions and 7 deletions.
24 changes: 18 additions & 6 deletions project/algorithms/jax_algo.py
@@ -95,15 +95,13 @@ def __init__(
         self.hp: JaxAlgorithm.HParams
         key = jax.random.key(self.hp.seed)
         # todo: Extract out the "network" portion, and probably use something like flax for it.
-        params = ParamsTuple(
+        self.jax_params = ParamsTuple(
             w1=jax.random.uniform(key=jax.random.fold_in(key, 1), shape=(input_dims, 128)),
             b1=jax.random.uniform(key=jax.random.fold_in(key, 2), shape=(128,)),
             w2=jax.random.uniform(key=jax.random.fold_in(key, 3), shape=(128, output_dims)),
             b2=jax.random.uniform(key=jax.random.fold_in(key, 4), shape=(output_dims,)),
         )
-        self.params = torch.nn.ParameterList(
-            [torch.nn.Parameter(v, requires_grad=True) for v in map(jax_to_torch, params)]
-        )
+
         self.forward_pass = forward_pass
         self.backward_pass = value_and_grad(self.forward_pass)

@@ -114,6 +112,19 @@ def __init__(
         # We will do the backward pass ourselves, and PL will synchronize stuff between workers, etc.
         self.automatic_optimization = False

+    @property
+    def jax_params(self):
+        # View the torch parameters as jax Arrays
+        return ParamsTuple(**{k: torch_to_jax(p.data) for k, p in self.named_parameters()})
+
+    @jax_params.setter
+    def jax_params(self, value: ParamsTuple[jax.Array]):
+        for k, jax_v in value._asdict().items():
+            assert isinstance(jax_v, jax.Array)
+            torch_v = jax_to_torch(jax_v)
+            p: torch.nn.Parameter = torch.nn.Parameter(torch_v, requires_grad=True)
+            self.register_parameter(k, p)
+
     def shared_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_index: int, phase: PhaseStr
     ):
@@ -123,8 +134,7 @@ def shared_step(
         # View/"convert" the torch inputs to jax Arrays.
         jax_x, jax_y = torch_to_jax(torch_x), torch_to_jax(torch_y)

-        # View the parameters as jax Arrays
-        jax_params = ParamsTuple(*map(torch_to_jax, self.parameters()))
+        jax_params = self.jax_params

         if phase != "train":
             # Only use the forward pass.
@@ -136,7 +146,9 @@ def shared_step(
             # Perform the backward pass
             (loss, logits), jax_grads = self.backward_pass(jax_params, jax_x, jax_y)
             with torch.no_grad():
+                # 'convert' the gradients to pytorch
                 torch_grads = map(jax_to_torch, jax_grads)
+                # Update the torch parameters tensors in-place using the jax grads.
                 for param, grad in zip(self.parameters(), torch_grads):
                     if param.grad is None:
                         param.grad = grad
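
For reference, below is a minimal, self-contained sketch (not the project's actual module) of the pattern this commit is cleaning up: torch owns the parameters, jax computes the loss and the gradients with value_and_grad, and the gradients are copied back into the torch parameters' .grad fields so an ordinary torch optimizer can step. The torch_to_jax / jax_to_torch helpers here are copy-based stand-ins; the project's converters presumably share memory instead of copying, and ParamsTuple / forward_pass are simplified.

# Sketch only: a stand-alone illustration of torch-owned parameters with a jax
# backward pass. torch_to_jax / jax_to_torch are copy-based stand-ins for the
# project's converters; ParamsTuple and forward_pass are simplified.
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np
import torch


class ParamsTuple(NamedTuple):
    w: jax.Array
    b: jax.Array


def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # The real helper presumably avoids this host round-trip.
    return jnp.asarray(t.detach().cpu().numpy())


def jax_to_torch(a: jax.Array) -> torch.Tensor:
    return torch.as_tensor(np.asarray(a))


def forward_pass(params: ParamsTuple, x: jax.Array, y: jax.Array):
    logits = x @ params.w + params.b
    loss = jnp.mean((logits - y) ** 2)
    return loss, logits


# has_aux=True because forward_pass returns (loss, logits).
backward_pass = jax.value_and_grad(forward_pass, has_aux=True)

# Torch owns the parameters (what the jax_params setter registers on the module).
torch_params = {
    "w": torch.nn.Parameter(torch.randn(4, 2)),
    "b": torch.nn.Parameter(torch.zeros(2)),
}
optimizer = torch.optim.SGD(torch_params.values(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 2)

# View the torch parameters and inputs as jax arrays (what the getter does).
jax_params = ParamsTuple(**{k: torch_to_jax(p.data) for k, p in torch_params.items()})
(loss, logits), jax_grads = backward_pass(jax_params, torch_to_jax(x), torch_to_jax(y))

# Copy the jax grads into the torch parameters' .grad fields, then step.
with torch.no_grad():
    for param, grad in zip(torch_params.values(), map(jax_to_torch, jax_grads)):
        param.grad = grad if param.grad is None else param.grad + grad
optimizer.step()

Since the backward pass happens in jax rather than through torch autograd, the Lightning module sets automatic_optimization = False and fills in .grad itself, as the diff above shows.
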
3 changes: 2 additions & 1 deletion project/datamodules/vision/base.py
@@ -79,8 +79,9 @@ def __init__(
         """

         super().__init__()
+        from project.configs.datamodule import DATA_DIR

-        self.data_dir = data_dir if data_dir is not None else os.getcwd()
+        self.data_dir = data_dir if data_dir is not None else DATA_DIR
         self.val_split = val_split
         if num_workers is None:
             num_workers = num_cpus_on_node()
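
The new default depends on a DATA_DIR constant from project.configs.datamodule, which is not shown in this diff. Purely as a hypothetical illustration (not the project's actual definition), such a constant could be an environment-variable override with a working-directory fallback, which would preserve the old os.getcwd() behaviour when nothing is configured:

# Hypothetical sketch of a DATA_DIR constant like the one imported above;
# the project's real definition may differ.
import os
from pathlib import Path

DATA_DIR = Path(os.environ.get("DATA_DIR", os.getcwd()))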
