Skip to content

Commit

Permalink
reward classifier fix checkpoint save load
Browse files Browse the repository at this point in the history
Signed-off-by: youliang <[email protected]>
  • Loading branch information
youliangtan committed May 31, 2024
1 parent d399ba1 commit 22abe93
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle as pkl
import jax
from jax import numpy as jnp
import flax
import flax.linen as nn
from flax.training import checkpoints
import optax
Expand Down Expand Up @@ -161,6 +162,8 @@ def loss_fn(params):
f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}"
)

# Disable Orbax checkpointing so the checkpoint is saved without Orbax
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(
FLAGS.classifier_ckpt_path,
classifier,
Expand Down
4 changes: 3 additions & 1 deletion examples/async_cable_route_drq/train_reward_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pickle as pkl
import jax
from jax import numpy as jnp
import flax
import flax.linen as nn
from flax.training import checkpoints
import optax
Expand Down Expand Up @@ -161,7 +162,8 @@ def loss_fn(params):
print(
f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}"
)

# Disable Orbax checkpointing so the checkpoint is saved without Orbax
flax.config.update('flax_use_orbax_checkpointing', False)
checkpoints.save_checkpoint(
FLAGS.classifier_ckpt_path,
classifier,
Expand Down
5 changes: 3 additions & 2 deletions serl_launcher/serl_launcher/networks/reward_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from flax.training.train_state import TrainState
from flax.training import checkpoints
import optax
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional


from serl_launcher.vision.resnet_v1 import resnetv1_configs, PreTrainedResNetEncoder
Expand Down Expand Up @@ -94,6 +94,7 @@ def load_classifier_func(
sample: Dict,
image_keys: List[str],
checkpoint_path: str,
step: Optional[int] = None
) -> Callable[[Dict], jnp.ndarray]:
"""
Return: a function that takes in an observation
Expand All @@ -103,7 +104,7 @@ def load_classifier_func(
classifier = checkpoints.restore_checkpoint(
checkpoint_path,
target=classifier,
step=100,
step=step,
)
func = lambda obs: classifier.apply_fn(
{"params": classifier.params}, obs, train=False
Expand Down

0 comments on commit 22abe93

Please sign in to comment.