
Flux.params for primitive types #1991

Open
FelixBenning opened this issue Jun 7, 2022 · 1 comment

Comments

@FelixBenning

Flux.params does not work for primitive types, e.g.

[Screenshot 2022-06-07 at 10.14.45: example not available]

I assume this is because Zygote uses the pointer to each array to identify it, and differentiates with respect to every use of that pointer. A struct field of primitive type does not hold a pointer; it holds the value itself, so this approach breaks.

Ideally, primitive types would be supported (with the pointer being the pointer to the model struct plus the field offset).

But it might be more realistic to use the output of Flux.functor(model) to print a warning about any primitive (or otherwise incompatible) types when Flux.params is used.

@mcabbott
Member

mcabbott commented Jun 7, 2022

Yes, Zygote's implicit mode does not track scalars. It works by objectid, which is a stable identity for things like arrays. The plan is for Flux to stop using this in favour of explicit mode (things like gradient(m -> loss(m, x, y), model), without Params / Grads), which has no such restriction and no global variables.
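A quick way to see the identity rule in plain Julia, without Flux or Zygote: `===` (and therefore `objectid`) distinguishes two arrays with equal contents, but cannot tell two equal `Float64` values apart. This is only a sketch of Julia's identity semantics; it does not reproduce Zygote's internals.

```julia
# Arrays are mutable, so each one has its own identity, even with equal contents.
a = [0.0]
b = [0.0]
@show a == b                      # same values
@show a === b                     # but two distinct objects
@show objectid(a) == objectid(b)

# Float64 is an immutable bits type: equal values are indistinguishable by ===,
# so a scalar field has no identity of its own for an objectid-keyed tracker.
x = 0.0
y = 0.0
@show x === y
@show objectid(x) == objectid(y)
```

Julia's documentation guarantees that `x === y` implies `objectid(x) == objectid(y)`, which is why two scalars holding the same value collapse to one key.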

The issue about this transition is #1986. At the moment, Optimisers.jl won't update scalars; it only acts on arrays of numbers. This could be widened, but the problem is that Functors.jl also uses objectid to decide whether two branches of the tree are identical, so several scalars that are initially 0.0 would be permanently locked together. FluxML/Functors.jl#39 is one attempt to fix this.
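The locking-together effect can be sketched with a toy tree walk that memoises results in an `IdDict`, assuming only that the cache is keyed by identity (`===`) as described above. `cached_map` and `bump` are hypothetical names made up for this illustration; this is not Functors.jl's actual code.

```julia
# Toy walk over leaves that caches results by object identity (===).
# Illustration only: the same idea as an identity cache, not Functors.jl itself.
function cached_map(f, leaves::Tuple)
    cache = IdDict{Any,Any}()
    # get! only calls f(x) if no identical (===) key has been seen before
    return map(x -> get!(() -> f(x), cache, x), leaves)
end

calls = Ref(0)
# Hypothetical "update" whose result differs on every real call
bump(x) = (calls[] += 1; x .+ calls[])

# Two scalar leaves that both start at 0.0, plus two separate arrays
leaves = (0.0, 0.0, [0.0], [0.0])
out = cached_map(bump, leaves)
@show out
@show calls[]
```

The second `0.0` hits the cache entry created by the first, so both scalars receive the same result and `bump` runs only three times, while the two arrays, being distinct objects, are updated independently.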

> might be more realistic to use the output of Flux.functor(model) to print a warning about any primitive (or otherwise incompatible) types

The reason it doesn't, BTW, is that models often contain numbers we don't want to treat as parameters, such as the strides of a convolution.
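To see why a blanket warning would be noisy, here is a minimal sketch of the proposed check applied to a made-up `ToyConv` layer. `ToyConv` and `scalar_fields` are hypothetical, the layer is a simplification (Flux's real `Conv` stores its stride differently), and plain field reflection stands in for `Flux.functor` so the example needs no dependencies.

```julia
# Hypothetical layer loosely modelled on a convolution:
# `weight` is trainable, `stride` is a hyperparameter.
struct ToyConv
    weight::Array{Float32,4}
    stride::Int
end

# Naive check of the kind proposed: flag every field that holds a Number.
scalar_fields(x) = [f for f in fieldnames(typeof(x)) if getfield(x, f) isa Number]

layer = ToyConv(zeros(Float32, 3, 3, 1, 1), 1)
flagged = scalar_fields(layer)
@show flagged
```

The only field flagged is `stride`, which is exactly the kind of number that should not be treated as a parameter, so such a warning would fire on ordinary models.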
