Loading models doesn't work #1

Open
ArnoDekkersLos opened this issue Apr 24, 2019 · 1 comment
Comments

@ArnoDekkersLos

Hoping that this project is not abandoned and that you're willing to patch this:
When I try to load a saved model with the Agent.load_model method, Keras throws the exception 'ValueError: Unknown loss function:loss'.

The usual solution is to change the line:

    self.actor_network = load_model(self.dic_path["PATH_TO_MODEL"], "%s_actor_network.h5")

which I had already changed to:

    self.actor_network = load_model("ppo/actor_network.h5")

so that it reads:

    self.actor_network = load_model("ppo/actor_network.h5", custom_objects={'loss': self.loss})

However, the loss function is an inner (nested) function, so it cannot be referenced from outside. When I try to use proximal_policy_optimization_loss (which generates the loss function) instead, it throws the exception 'AttributeError: 'function' object has no attribute 'get_shape''.
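Both errors can be reproduced with a simplified, pure-Python model of how a saved model refers to its loss. This is a toy under stated assumptions, not the Keras API: `BUILTIN_LOSSES`, `resolve_loss` and `ppo_loss_factory` are hypothetical names, and the loss is a scalar version of the clipped PPO objective. It shows that only the loss's name is stored (here the inner function's name, `loss`). That name must map to the inner function itself. Mapping it to the factory makes the lookup return a new function instead of a value.

```python
# Simplified toy model (NOT the Keras API): a saved model only records the
# loss function's __name__, which is resolved again by name when loading.

BUILTIN_LOSSES = {"mse": lambda y_true, y_pred: (y_true - y_pred) ** 2}


def ppo_loss_factory(advantage, old_prediction, clip=0.2):
    """Hypothetical scalar stand-in for proximal_policy_optimization_loss."""
    def loss(y_true, y_pred):
        # y_true is unused in this scalar toy; a typical PPO loss uses it to
        # pick out the probability of the action that was actually taken.
        ratio = y_pred / old_prediction
        clipped_ratio = max(min(ratio, 1 + clip), 1 - clip)
        # Clipped surrogate objective, negated so that it can be minimised.
        return -min(ratio * advantage, clipped_ratio * advantage)
    return loss


def resolve_loss(saved_name, custom_objects=None):
    """Mimic the by-name lookup performed when a saved model is loaded."""
    table = dict(BUILTIN_LOSSES)
    table.update(custom_objects or {})
    if saved_name not in table:
        raise ValueError("Unknown loss function:%s" % saved_name)
    return table[saved_name]


if __name__ == "__main__":
    inner = ppo_loss_factory(advantage=1.0, old_prediction=0.5)
    saved_name = inner.__name__  # "loss", the name of the inner function

    try:
        resolve_loss(saved_name)  # no custom_objects: the name is unknown
    except ValueError as err:
        print("without custom_objects:", err)

    # Mapping the name to the factory "works", but calling it with
    # (y_true, y_pred) builds another loss function instead of a value.
    wrong = resolve_loss(saved_name, {"loss": ppo_loss_factory})
    print("factory result is a function:", callable(wrong(1.0, 0.6)))

    # Mapping the name to the inner function returns an actual number.
    right = resolve_loss(saved_name, {"loss": inner})
    print("inner function result:", right(1.0, 0.6))
```

In the toy, the factory's result is a function where a loss value is expected. That is consistent with the 'get_shape' error when Keras tries to treat it as a tensor. The real inner function also closes over the values its factory was called with, so it cannot easily be rebuilt from its name alone.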

I've tried to fix this by:
- loading the weights rather than the model
- creating a standalone loss function that uses self.parameters inside it
- creating a standalone loss function and wrapping it in a lambda (https://stackoverflow.com/a/54177997/8579225)

but none of these fixed the problem.
I hope you're willing to help me out with this.

@navallo

navallo commented Jul 15, 2019

It's really late, but here is a solution:
1. Save and load the weights rather than the whole model.
2. Use 'build_network_from_copy' rather than 'deepcopy' to recreate actor_old_network.

That is:

    import os

    # Save only the weights, so the custom PPO loss never has to be deserialised.
    def save_model_weights(self, file_name):
        self.actor_network.save_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_actor_weights.h5" % file_name))
        self.critic_network.save_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_critic_weights.h5" % file_name))

    # Load weights into the existing networks (they must already be built),
    # then rebuild the old actor network from the freshly loaded actor.
    def load_model_weights(self, file_name):
        self.actor_network.load_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_actor_weights.h5" % file_name))
        self.critic_network.load_weights(os.path.join(self.dic_path["PATH_TO_MODEL"], "%s_critic_weights.h5" % file_name))
        self.actor_old_network = self.build_network_from_copy(self.actor_network)
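The idea behind this reply is to keep the architecture and the custom loss in code and put only numeric weights on disk. It can be sketched without Keras. Everything here is a hypothetical toy. `ToyNetwork`, `build_actor` and this `build_network_from_copy` only mirror the roles of a Keras model, the agent's network builder and the repository's helper. JSON stands in for HDF5.

```python
import json
import os
import tempfile


class ToyNetwork:
    """Hypothetical stand-in for a Keras model: the architecture is defined
    in code and only the numeric weights are written to disk."""

    def __init__(self, sizes):
        self.sizes = list(sizes)
        self.weights = [[0.0] * n for n in self.sizes]

    def get_weights(self):
        return [list(w) for w in self.weights]

    def set_weights(self, weights):
        if [len(w) for w in weights] != self.sizes:
            raise ValueError("weight shapes do not match the architecture")
        self.weights = [list(w) for w in weights]

    def save_weights(self, path):
        with open(path, "w") as f:
            json.dump(self.weights, f)

    def load_weights(self, path):
        # Only works on a network that has already been built in code.
        with open(path) as f:
            self.set_weights(json.load(f))


def build_actor():
    # The architecture is recreated by code. In the real agent this is
    # presumably also where the custom loss gets compiled (an assumption),
    # so nothing has to be looked up by name when loading.
    return ToyNetwork([4, 2])


def build_network_from_copy(network):
    # Build a fresh network and copy the weights into it, instead of
    # deep-copying the model object itself.
    fresh = build_actor()
    fresh.set_weights(network.get_weights())
    return fresh


if __name__ == "__main__":
    actor = build_actor()
    actor.set_weights([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6]])
    path = os.path.join(tempfile.mkdtemp(), "run1_actor_weights.json")
    actor.save_weights(path)

    restored = build_actor()
    restored.load_weights(path)
    actor_old = build_network_from_copy(restored)
    print(restored.get_weights() == actor.get_weights())
    print(actor_old.weights is not restored.weights)
```

Because `set_weights` copies the lists, the rebuilt old actor is independent of the network it was copied from. That is the property a separate actor_old_network needs.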
