Currently, it doesn't seem possible to straightforwardly checkpoint (with Orbax) an NNX module that includes random keys (e.g. for dropout); see google/orbax#1105 (comment). This seems to be due to the new JAX typed random key type (https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html), which is used in NNX (`flax/nnx/rnglib.py`, line 186 at commit `fc19c5d`). Although Orbax has added individual support for the new type (see google/orbax#620), saving `nnx.state(model)` when it includes arrays with `dtype=key<fry>` doesn't seem to be possible.
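A possible interim workaround, sketched here under assumptions that have not been checked against Orbax or `nnx.State`: before saving, replace every typed key leaf with its raw `uint32` data using `jax.random.key_data`, and after restoring, re-wrap those leaves with `jax.random.wrap_key_data`. The helper names `is_typed_key`, `keys_to_raw`, and `raw_to_keys` are hypothetical, a plain dict stands in for `nnx.state(model)`, and the keys are assumed to use JAX's default threefry (`fry`) implementation.

```python
import jax
import jax.numpy as jnp


def is_typed_key(x):
    # True for new-style typed PRNG key arrays, e.g. dtype key<fry>.
    return hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def keys_to_raw(tree):
    # Replace each typed key leaf with its raw uint32 key data, which
    # serializers that only understand ordinary numeric arrays can store.
    return jax.tree_util.tree_map(
        lambda x: jax.random.key_data(x) if is_typed_key(x) else x, tree
    )


def raw_to_keys(raw_tree, template):
    # Walk `template` (same structure as the saved tree) to decide which
    # leaves to re-wrap as typed keys. Assumes the keys were created with
    # the default PRNG implementation (threefry, shown as "fry").
    return jax.tree_util.tree_map(
        lambda t, r: jax.random.wrap_key_data(r) if is_typed_key(t) else r,
        template,
        raw_tree,
    )


# Hypothetical stand-in for nnx.state(model): one parameter and one
# typed dropout key.
state = {
    "dense": {"kernel": jnp.ones((2, 2))},
    "rngs": {"dropout": jax.random.key(0)},
}

raw = keys_to_raw(state)            # only uint32 / float32 leaves remain
restored = raw_to_keys(raw, state)  # typed keys are back

print(state["rngs"]["dropout"].dtype)  # key<fry>
print(raw["rngs"]["dropout"].dtype)    # uint32
```

The `raw` tree is what would be handed to the checkpointer; the original state (or an abstract copy of it) serves as the template when restoring, so the key leaves come back with `dtype=key<fry>` and the same underlying bits.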