-
The more low-level save/load functions are in
Note that you can always pass |
-
@jheek I didn't see any file-IO-level save/load in serialize, just to/from bytes serialization. The idea was to split the file-IO-level checkpointing functionality into two parts: plain save/load, and history management. I realize bytes -> file isn't a big deal, but it's not something you want every user to rewrite themselves. Along the lines of functionality people are used to in other libraries, you might also want to add model zoo / download functionality in a (non-training) checkpoint-handling scope at some point.
For checkpoint restore, using dataclass fields could be a good option. I'm aware of the current None behaviour. I was actually surprised when it first happened, and I'm not yet convinced that using None makes sense, given that None carries optional semantics in Python. Looking ahead to when Flax is fully fleshed out, it seems the need to work with the raw state dict will be less common than, say, handling optional fields. I'd rather not write code to handle raw state dicts if I don't have to, and most of my need to do so (at least right now) is for forward/backward compatibility and optional fields.
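A minimal sketch of the thin bytes -> file layer described above, using only the standard library. save_bytes and load_bytes are hypothetical helper names, not Flax APIs. The write goes through a temporary file and os.replace so that an interrupted save does not leave a truncated checkpoint behind.

```python
import os
import tempfile


def save_bytes(path: str, data: bytes) -> None:
    """Hypothetical helper: write serialized checkpoint bytes to `path` atomically.

    The bytes go to a temporary file in the same directory first and are then
    renamed into place, so a crash mid-write never leaves a partial file at `path`.
    """
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_ckpt_")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes reach disk before the rename
        os.replace(tmp_path, path)  # atomic rename on POSIX filesystems
    except BaseException:
        # Remove the partial temp file if anything went wrong, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load_bytes(path: str) -> bytes:
    """Hypothetical helper: read back the raw bytes written by save_bytes."""
    with open(path, "rb") as f:
        return f.read()
```

With Flax, the bytes would come from flax.serialization.to_bytes(state), and restoring would be flax.serialization.from_bytes(target, load_bytes(path)), so neither helper needs to know anything about step numbers or history.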
-
While working on a training script recently, I've run into a few checkpointing-related items that I feel would improve usability.
save/load functions
The current save/restore checkpoint routines include code for managing checkpoint history (sorting by step, restoring the latest, etc.). There is no basic save/load function that can be used without those extras. I'm about to write my own checkpoint manager, since I want to manage history based on eval metrics, and I will have to rewrite the save/load serialization bits in my own code.
It'd be nice to split the base save/load functionality from the history management.
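A minimal sketch of the kind of metric-based history manager described above, assuming checkpoints are already serialized to bytes (for example with flax.serialization.to_bytes). MetricCheckpointManager, its checkpoint_<step> file naming, and its keep/mode parameters are all hypothetical, not part of Flax. The point is that history management only decides which files to keep, so it layers cleanly on top of a plain save/load.

```python
import os
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class MetricCheckpointManager:
    """Hypothetical manager that keeps the `keep` best checkpoints by an eval metric."""

    directory: str
    keep: int = 3
    mode: str = "max"  # "max": higher metric is better; "min": lower is better
    # (metric, step, path) for every checkpoint currently on disk
    _entries: List[Tuple[float, int, str]] = field(
        default_factory=list, init=False, repr=False
    )

    def __post_init__(self) -> None:
        if self.mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")

    def _path(self, step: int) -> str:
        return os.path.join(self.directory, f"checkpoint_{step}")

    def save(self, step: int, metric: float, data: bytes) -> Optional[str]:
        """Write `data` for `step`, then prune so at most `keep` files remain.

        Returns the written path, or None if this checkpoint ranked outside the
        best `keep` and was removed right away.
        """
        os.makedirs(self.directory, exist_ok=True)
        path = self._path(step)
        with open(path, "wb") as f:
            f.write(data)
        # Re-saving a step replaces its old entry instead of duplicating it.
        self._entries = [e for e in self._entries if e[1] != step]
        self._entries.append((metric, step, path))
        # Sort best-first; on equal metrics the newer step wins.
        sign = 1.0 if self.mode == "max" else -1.0
        self._entries.sort(key=lambda e: (sign * e[0], e[1]), reverse=True)
        kept, dropped = self._entries[: self.keep], self._entries[self.keep :]
        for _, _, old_path in dropped:
            if os.path.exists(old_path):
                os.remove(old_path)
        self._entries = kept
        return path if any(p == path for _, _, p in kept) else None

    def best_path(self) -> Optional[str]:
        return self._entries[0][2] if self._entries else None

    def load_best(self) -> Optional[bytes]:
        path = self.best_path()
        if path is None:
            return None
        with open(path, "rb") as f:
            return f.read()
```

Because the manager only tracks (metric, step, path) tuples, swapping the ranking rule (latest step, best metric, or both) doesn't touch the serialization code at all.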
strict flag
Working with training code based on the ImageNet example, I have a TrainState dataclass. Adding an extra field that I set up with safe defaults (so a new training session should be backwards compatible with old state) fails, because the from_state_dict fn for struct.dataclass is quite strict. Adding something like a strict flag would improve this for my use...
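A minimal, framework-free sketch of the behaviour such a strict flag could enable, using only the standard library's dataclasses. restore_dataclass and its strict parameter are hypothetical, not Flax's actual from_state_dict API. As with Flax's restore-into-a-target pattern, the target instance supplies the values used for any field missing from an older checkpoint.

```python
import dataclasses
from typing import Any, Dict, TypeVar

T = TypeVar("T")


def restore_dataclass(target: T, state: Dict[str, Any], strict: bool = True) -> T:
    """Hypothetical restore of a dataclass instance from a plain state dict.

    strict=True: every init field must appear in `state` and unknown keys raise.
    strict=False: fields missing from `state` keep the value already on `target`
    (e.g. a safe default for a newly added field), and unknown keys are ignored.
    Nested dataclass fields are restored recursively with the same flag.
    """
    if not dataclasses.is_dataclass(target) or isinstance(target, type):
        raise TypeError("target must be a dataclass instance")
    # Fields with init=False cannot be passed to dataclasses.replace, so skip them.
    names = [f.name for f in dataclasses.fields(target) if f.init]
    if strict:
        missing = [n for n in names if n not in state]
        unknown = [k for k in state if k not in names]
        if missing or unknown:
            raise ValueError(f"state mismatch: missing={missing}, unknown={unknown}")
    updates = {}
    for name in names:
        if name not in state:
            continue  # only reachable when strict=False: keep the target's value
        current, value = getattr(target, name), state[name]
        if (
            dataclasses.is_dataclass(current)
            and not isinstance(current, type)
            and isinstance(value, dict)
        ):
            value = restore_dataclass(current, value, strict=strict)
        updates[name] = value
    # replace() builds a new instance, so this also works for frozen dataclasses.
    return dataclasses.replace(target, **updates)
```

Defaulting strict to True keeps today's fail-fast behaviour, while strict=False makes forward/backward compatibility an explicit opt-in rather than something every training script has to hand-roll around raw state dicts.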