Replies: 2 comments 2 replies
-
I'm not sure what you mean by "all the different collections": the base Modules only use 'params', 'batch_stats' (BatchNorm), and 'cache' (for decoding in autoregressive attention). The idea here is precisely *not* to put everything that is not a parameter into one big collection. Consider, for example, syncing the batch statistics across devices for a certain model. If all of its internal state were in 'state', how would I make sure that I sync only the batch statistics?
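For illustration, here is a minimal sketch (Flax Linen; the `Model` module and shapes are made up) of how a separate 'batch_stats' collection lets a cross-device sync target exactly those arrays:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Model(nn.Module):
    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(16)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return x

variables = Model().init(jax.random.PRNGKey(0), jnp.ones((4, 8)))
# `variables` is keyed by collection: {'params': ..., 'batch_stats': ...}

# Sync *only* the batch statistics across devices; 'params' (and any
# other collection) is left untouched.
replicated = jax.device_put_replicated(variables['batch_stats'],
                                       jax.local_devices())
synced = jax.pmap(lambda bs: jax.lax.pmean(bs, axis_name='batch'),
                  axis_name='batch')(replicated)
```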
-
I see, thanks for the info! In general, I am wondering what the current recommendation is for the following scenario: say you train Module A and serialize its variables to disk. Later, you create a Module B that uses Module A and want to inject the pre-trained parameters of A. Similarly, what if you want to do the opposite: you train B and then want to extract just A for inference (e.g. A is a decoder and you want to generate samples)?
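A minimal sketch of one way to express both directions of this scenario with Flax Linen (an illustration, not necessarily the recommended approach being asked about; the modules `A`/`B` and the fixed submodule name 'a' are made up). Since params are nested dicts keyed by submodule name, both injection and extraction reduce to dict manipulation:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import serialization
from flax.core import freeze, unfreeze

class A(nn.Module):                  # the pre-trained piece
    @nn.compact
    def __call__(self, x):
        return nn.Dense(8)(x)

class B(nn.Module):                  # wraps A plus a new head
    @nn.compact
    def __call__(self, x):
        x = A(name='a')(x)           # fixed name => stable param path
        return nn.Dense(2)(x)

# 1) Train A, then serialize its variables to disk.
a_vars = A().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
raw = serialization.to_bytes(a_vars)

# 2) Init B, then graft A's pre-trained params under B's 'a' subtree.
b_vars = B().init(jax.random.PRNGKey(1), jnp.ones((1, 4)))
a_restored = serialization.from_bytes(a_vars, raw)
b_params = unfreeze(b_vars)['params']
b_params['a'] = unfreeze(a_restored)['params']
b_vars = freeze({'params': b_params})

# 3) The opposite direction: pull just A's params back out for inference.
a_only = freeze({'params': b_vars['params']['a']})
y = A().apply(a_only, jnp.ones((1, 4)))
```

Because the param tree mirrors the module tree, giving the shared submodule a fixed `name` keeps its path stable across architectures, which is what makes this kind of surgery mechanical.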
-
Hey,
I was wondering what the purpose of variable collections is. I don't know what the current state of transfer learning in Flax is, but since it requires "parameter surgery", my guess is that having all these different collections makes getting previous parameters into their correct positions in the new architecture increasingly more challenging.
Proposal
Remove / hide the `Module.variable(col)` parameter and just have 2 canonical collections:
- `params`: what we currently have for the trainable parameters
- `states`: name for the non-trainable parameters.
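For context, a minimal sketch of what `Module.variable(col)` allows today in Flax Linen (the `Counter` module is illustrative): it declares a variable in an arbitrary, user-named collection, which under this proposal would be folded into a single 'states' collection:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Counter(nn.Module):
    @nn.compact
    def __call__(self):
        # 'counter' is an arbitrary collection name chosen by the user;
        # the proposal would fold such collections into one 'states'.
        count = self.variable('counter', 'count',
                              lambda: jnp.zeros((), jnp.int32))
        count.value += 1
        return count.value

variables = Counter().init(jax.random.PRNGKey(0))
# {'counter': {'count': Array(1, dtype=int32)}}
out, variables = Counter().apply(variables, mutable=['counter'])
```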