Skip to content

Commit

Permalink
Fix encoder training
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico committed Jul 17, 2023
1 parent 3f4be4d commit 5c28ae9
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sheeprl/algos/sac_pixel/sac_pixel_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
def train(
fabric: Fabric,
agent: SACPixelAgent,
encoder: Union[Encoder, _FabricModule],
decoder: Union[Decoder, _FabricModule],
actor_optimizer: Optimizer,
qf_optimizer: Optimizer,
Expand Down Expand Up @@ -101,7 +102,7 @@ def train(

# Update the decoder
if global_step % args.decoder_update_freq == 0:
hidden = agent.critic.encoder(normalized_obs)
hidden = encoder(normalized_obs)
reconstruction = decoder(hidden)
reconstruction_loss = (
F.mse_loss(preprocess_obs(data["observations"], bits=5), reconstruction) # Reconstruction
Expand Down Expand Up @@ -359,6 +360,7 @@ def main():
train(
fabric,
agent,
encoder,
decoder,
actor_optimizer,
qf_optimizer,
Expand Down

0 comments on commit 5c28ae9

Please sign in to comment.