Replies: 1 comment 5 replies
-
Hey, we currently don't have a built-in way to do that, but you can rebuild the model with abstract shapes and then restore its parameters:

```python
import orbax.checkpoint as orbax
from flax import nnx

# MLP and create_model are defined in the example file referenced below.

def load_model(path: str) -> MLP:
    # create the model with abstract shapes (no parameters allocated)
    model = nnx.eval_shape(lambda: create_model(0))
    state = nnx.state(model)
    # load the parameters
    checkpointer = orbax.PyTreeCheckpointer()
    state = checkpointer.restore(f'{path}/state', item=state)
    # update the model with the loaded state
    nnx.update(model, state)
    return model
```

This is taken from 08_save_load_checkpoints.py.
-
Hi,
I want to be able to serialize multiple different nnx models to disk (not just the weights but also the full layer structure). This is helpful when trying out a bunch of different architectures that I trained beforehand and just want to test in eval/inference mode.
Currently, I am using Orbax to save the model train state, but this requires that the train-state structure is created before loading the checkpoint. I am doing something like this:
What I would like to do is something like this:
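Presumably something along these lines, where a single call persists both structure and weights; note that `save_model` and `load_model` here are hypothetical names, since no such API exists in nnx:

```python
# Hypothetical API, not part of flax/nnx: one call saves structure + weights,
# and loading does not require re-running create_model first.
nnx.save_model(model, '/tmp/my_model')
model = nnx.load_model('/tmp/my_model')
```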
This may be possible by pickling the create_model function (though that could fail because of lambda functions inside create_model), but I guess this is not the idiomatic way.
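The worry about lambdas is well founded: pickle serializes functions by qualified name, and a lambda's name is `<lambda>`, which cannot be looked up again on unpickling. A stdlib-only illustration, independent of nnx:

```python
import pickle

# Pickling plain data (what a model factory might return) is fine...
model_repr = {"layers": [8, 8, 2]}
blob = pickle.dumps(model_repr)
restored = pickle.loads(blob)

# ...but pickling a lambda fails: pickle stores functions by reference
# (module + qualified name), and '<lambda>' cannot be resolved.
make = lambda: {"layers": [8, 8, 2]}
try:
    pickle.dumps(make)
    lambda_picklable = True
except Exception:  # pickle.PicklingError on CPython
    lambda_picklable = False
```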
In torch you can do:
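The PyTorch idiom being referred to is presumably full-module serialization with torch.save/torch.load, which pickles the architecture together with the weights (on recent PyTorch versions, torch.load needs weights_only=False to unpickle arbitrary objects):

```python
import os
import tempfile

import torch
import torch.nn as nn

# torch.save pickles the entire module: layer structure and weights together.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.pt")
    torch.save(model, path)
    # weights_only=False lets torch.load unpickle the full module object
    restored = torch.load(path, weights_only=False)

restored.eval()  # ready for inference without rebuilding the architecture
```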
I basically would like to use the same kind of API as in torch, but with flax/nnx.