Feature: Allow serialization of custom networks #284
This commit adds utility functions and extends existing networks to enable serialization of complete networks when custom network types are passed as arguments (e.g., for sub-networks in coupling flows).

The main complications were:

* Objects of type `type` (uninstantiated classes) cannot be serialized using `keras.saving.serialize_keras_object`, as they have no `get_config` function.
* We want to support both strings and types as parameters, so the two have to be distinguished during manual serialization/deserialization.
* Auto-discovery of `__init__` parameters is only active when `get_config` is not overridden, so the configuration has to be stored manually for serialization.

For storing the types, we use `keras.saving.get_registered_name`; the type can be reconstructed at deserialization using `keras.saving.get_registered_object`.

Handling the different cases is moved to the utility functions `(de)serialize_val_or_type`, which use a naming scheme to determine which deserialization method to use.

The same setup can be extended to other custom types, e.g. distributions.
If you agree with the general design, I think the main question is which parameters we want to use this for. For now, I have only implemented it for subnets, but it should be easily transferable to other parameters like distributions. Do we want to apply this broadly before merging, or do we do it incrementally?

I am not totally happy with my adaptations to the tests, but I did not see a better way to solve this. Simply adding the subnet fixture to all inference nets doubles the test time for the networks, from roughly 5 min to 10 min. As this is only relevant for serialization, I deemed this unacceptable and split the tests up, which reduces the runtime but makes the test code less readable. Any ideas for making those tests prettier or easier to maintain are welcome.
Thank you for working on this feature! I think it is very important to have! I am not the best person to ask about the internals, so I am referring to @LarsKue and @stefanradev93 for a proper review.
Hi Valentin, I like the general design. I would only opt for a more verbose name of the utility:
Hey Stefan, thanks for taking a look, I have renamed the functions. I think for this set of changes we are ready to merge. If you have time you can comment on the following:
Thanks! I think the most pertinent usage will be for subnets. We can always enable it for distros later if there is demand. |
This PR addresses #228. It introduces utility functions and extends existing networks to enable serialization of complete networks when custom network types are passed as arguments (e.g., for sub-networks in coupling flows).

The main complications were:

* Objects of type `type` (uninstantiated classes) cannot be serialized using `keras.saving.serialize_keras_object`, as they have no `get_config` function. `keras.saving.get_registered_name` has to be used.
* We want to support both strings and types as parameters, so the two have to be distinguished during manual serialization/deserialization.
* Auto-discovery of `__init__` parameters is only active when `get_config` is not overridden, so the configuration has to be stored manually for serialization.

For storing the types, we use `keras.saving.get_registered_name`; the type can be reconstructed at deserialization using `keras.saving.get_registered_object`.

Handling the different cases is moved to the utility functions `(de)serialize_val_or_type`, which use a naming scheme to determine which deserialization method to use.

The same setup can be extended to other custom types, e.g. distributions.
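
To illustrate the idea, here is a minimal sketch of the naming-scheme approach. It is not the code from this PR: the function names, the `_type` key suffix, and the `MyMLP` and `CouplingFlow` classes are hypothetical. To keep it runnable without Keras, a tiny in-memory registry stands in for `keras.saving.get_registered_name` and `keras.saving.get_registered_object`.

```python
# Stand-in registry so the sketch runs without Keras installed. In real code,
# keras.saving.register_keras_serializable, keras.saving.get_registered_name
# and keras.saving.get_registered_object would replace these three helpers.
_REGISTRY = {}


def register(cls):
    """Register a class under a 'Custom>ClassName' key (hypothetical helper)."""
    _REGISTRY[f"Custom>{cls.__name__}"] = cls
    return cls


def get_registered_name(cls):
    """Return the registered key for cls, or its plain name if unregistered."""
    for key, value in _REGISTRY.items():
        if value is cls:
            return key
    return cls.__name__


def get_registered_object(name):
    """Return the class registered under name, or None if there is none."""
    return _REGISTRY.get(name)


TYPE_SUFFIX = "_type"  # assumed naming scheme, chosen for this example


def serialize_value_or_type(config, name, obj):
    """Store obj under `name` if it is a plain value (e.g. a string), or
    under `name + TYPE_SUFFIX` as a registered name if it is a class."""
    if isinstance(obj, type):
        return {**config, name + TYPE_SUFFIX: get_registered_name(obj)}
    return {**config, name: obj}


def deserialize_value_or_type(config, name):
    """Inverse of serialize_value_or_type: the key that is present decides
    which deserialization path is taken."""
    config = dict(config)
    type_key = name + TYPE_SUFFIX
    if type_key in config:
        registered_name = config.pop(type_key)
        cls = get_registered_object(registered_name)
        if cls is None:
            raise ValueError(f"No class registered as {registered_name!r}")
        config[name] = cls
    return config


@register
class MyMLP:
    """Hypothetical custom subnet class."""


class CouplingFlow:
    """Hypothetical network that accepts a subnet as a string or a class."""

    def __init__(self, subnet="mlp", depth=6):
        self.subnet = subnet
        self.depth = depth

    def get_config(self):
        # get_config is overridden, so every __init__ argument is stored by hand.
        config = {"depth": self.depth}
        return serialize_value_or_type(config, "subnet", self.subnet)

    @classmethod
    def from_config(cls, config):
        return cls(**deserialize_value_or_type(config, "subnet"))


flow = CouplingFlow(subnet=MyMLP, depth=3)
config = flow.get_config()
restored = CouplingFlow.from_config(config)
```

Because the class is stored as a plain string under a separate key, the resulting config stays JSON-compatible, and string arguments such as `"mlp"` pass through both functions unchanged.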