Add NNX support for legacy jax.random.PRNGKey() #4231

Open
cisprague opened this issue Sep 26, 2024 · 1 comment

@cisprague
Currently, it does not seem possible to checkpoint (with Orbax) an NNX module that contains random keys (for example, for dropout) in a straightforward way; see google/orbax#1105 (comment). This appears to be caused by the new typed JAX random keys (https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html), which NNX creates here:

key = jax.random.key(value)

Although Orbax has added standalone support for the new key type (see google/orbax#620), saving nnx.state(model) when it contains arrays with dtype=key<fry> does not seem to be possible.
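For context, here is a minimal illustration of the difference between the two key types. It assumes the default threefry PRNG implementation (the `fry` in `key<fry>`) and the default JAX config; the variable names are illustrative only:

```python
import jax

# New-style typed key, as created by jax.random.key (what NNX uses).
typed_key = jax.random.key(0)
# Legacy key: a plain uint32 array of shape (2,) under the default threefry impl.
legacy_key = jax.random.PRNGKey(0)

print(typed_key.dtype, typed_key.shape)    # key<fry> ()
print(legacy_key.dtype, legacy_key.shape)  # uint32 (2,)

# The typed key's dtype is a PRNG key dtype; the legacy key's is not.
print(jax.dtypes.issubdtype(typed_key.dtype, jax.dtypes.prng_key))   # True
print(jax.dtypes.issubdtype(legacy_key.dtype, jax.dtypes.prng_key))  # False

# jax.random.key_data recovers the raw uint32 data behind a typed key;
# for the same seed it matches the legacy key.
raw = jax.random.key_data(typed_key)
print(raw.dtype, raw.shape)  # uint32 (2,)
```

The legacy key is an ordinary numeric array, which is why serializers that predate typed keys can handle it. The typed key carries a special extended dtype that such tools may not recognize.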

@cgarciae
Collaborator

cgarciae commented Sep 27, 2024

Hey @cisprague, maybe you can convert the keys to the old (raw uint32) format before serializing?
You could use something like:

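import jax

def get_key_data(x):
  # Replace typed PRNG key arrays with their raw uint32 data; keep other leaves.
  if isinstance(x, jax.Array) and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key):
    return jax.random.key_data(x)
  return x

serializable_state = jax.tree.map(get_key_data, state)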

See PRNGKeys
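Extending the suggestion above to a full round trip, here is a minimal sketch. It makes two assumptions. First, the state is any JAX pytree; a plain dict stands in for nnx.state(model) here. Second, a reference state that still holds typed keys (for example, from a freshly constructed model) is available at restore time to tell which raw uint32 leaves were keys. The helper names is_typed_key, to_serializable and from_serializable are hypothetical, and the Orbax save/restore step itself is omitted:

```python
import jax
import jax.numpy as jnp


def is_typed_key(x):
    # Hypothetical helper: True for arrays whose dtype is a typed PRNG key, e.g. key<fry>.
    return isinstance(x, jax.Array) and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)


def to_serializable(state):
    # Hypothetical helper: replace every typed key with its raw uint32 key data.
    return jax.tree.map(lambda x: jax.random.key_data(x) if is_typed_key(x) else x, state)


def from_serializable(raw_state, reference_state):
    # Hypothetical helper: re-wrap raw key data wherever the reference state
    # held a typed key, reusing that key's PRNG implementation.
    return jax.tree.map(
        lambda ref, x: jax.random.wrap_key_data(x, impl=jax.random.key_impl(ref))
        if is_typed_key(ref)
        else x,
        reference_state,
        raw_state,
    )


# Stand-in for nnx.state(model): a dropout RNG key plus a weight (assumed layout).
state = {"dropout_key": jax.random.key(42), "kernel": jnp.ones((3, 2))}

raw = to_serializable(state)  # every leaf is now a plain numeric array
# ... save `raw` with Orbax, then restore it (not shown) ...
restored = from_serializable(raw, state)

print(raw["dropout_key"].dtype)       # uint32
print(restored["dropout_key"].dtype)  # key<fry>
```

Raw key data is just a uint32 array, so nothing in the saved checkpoint marks it as a key. Re-wrapping against a reference state is one way to recover that information. Passing impl=jax.random.key_impl(ref) keeps a non-default PRNG implementation intact instead of silently falling back to the default one.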
