Problem you have encountered:

When I try to initialize a Flax model on a specific GPU device (for example, GPU 1), the bias and kernel params end up on different GPU devices.

What you expected to happen:

The bias and kernel params should be placed on the same GPU device.
Steps to reproduce:

```python
import jax
import jax.numpy as jnp
from jax import tree_util
from flax import linen as nn

device = jax.devices("gpu")[1]


class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
        x = nn.relu(x)
        return x


rng = jax.random.PRNGKey(0)
rng = jax.device_put(rng, device)
dummy_input = jax.device_put(jnp.ones((5, 64, 64, 32)), device)

model = MyModel()
model_params = model.init({'params': rng}, dummy_input)
# model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)

# On newer JAX versions `device` is a property: use `x.device` instead of `x.device()`.
print(tree_util.tree_map(lambda x: (x.device()), model_params))
```
The output is:

```
FrozenDict({
    params: {
        conv1: {
            bias: gpu(id=0),
            kernel: gpu(id=1),
        },
    },
})
```
Thanks!
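To check for the split placement shown above without printing the whole tree, one can collect the devices of every leaf. This is a minimal sketch: the helper names `param_devices` and `assert_single_device` are hypothetical (not part of Flax or JAX), and it assumes a JAX version where `jax.Array.devices()` returns a set of devices.

```python
import jax
import jax.numpy as jnp


def param_devices(tree):
    """Collect the set of devices that hold any leaf of a parameter pytree."""
    devices = set()
    for leaf in jax.tree_util.tree_leaves(tree):
        # jax.Array.devices() returns a set of Device objects.
        devices |= leaf.devices()
    return devices


def assert_single_device(tree):
    """Raise if the leaves of `tree` live on more than one device; else return that device."""
    devices = param_devices(tree)
    if len(devices) != 1:
        raise ValueError(f"parameters are spread over several devices: {devices}")
    return next(iter(devices))


if __name__ == "__main__":
    # Stand-in for Flax params; in the reproduction this would be `model_params`.
    params = {"params": {"conv1": {"bias": jnp.zeros((64,)),
                                   "kernel": jnp.ones((3, 3, 32, 64))}}}
    print(assert_single_device(params))
```

Applied to `model_params` from the reproduction, `assert_single_device` should raise, since `bias` and `kernel` report different devices.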
Hi @YunxiTang, I am able to reproduce this issue.
In practice, I have seen Flax models initialized on the CPU and migrated or replicated to devices later. Two ways to get consistent placement here:
Initialize on `cpu`, then move the params to the target device:

```python
# Optional: Init on `cpu`.
model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
model_params = jax.device_put(model_params, device)
jax.tree.map(lambda p: p.device, model_params)
# {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
```

Use the `jax.default_device` context manager:

```python
with jax.default_device(device):
    model_params = model.init({'params': rng}, dummy_input)
print(tree_util.tree_map(lambda x: (x.device), model_params))
# {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
```
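For the "replicate to devices later" step, a CPU-initialized parameter tree can be copied to every device in a single `jax.device_put` call with a fully replicated sharding. This is a sketch assuming a JAX version that ships `jax.sharding.Mesh` and `NamedSharding`; the helper `replicate` and the mesh axis name `"devices"` are hypothetical choices for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def replicate(tree, devices=None):
    """Place a full copy of every leaf of `tree` on each of `devices`."""
    devices = jax.devices() if devices is None else devices
    # A 1-D mesh over the chosen devices; the axis name is arbitrary.
    mesh = Mesh(np.array(devices), ("devices",))
    # An empty PartitionSpec splits no axis, i.e. every device gets the whole array.
    return jax.device_put(tree, NamedSharding(mesh, PartitionSpec()))


if __name__ == "__main__":
    # Build the params on CPU first (stand-in for a CPU-side `model.init`).
    with jax.default_device(jax.devices("cpu")[0]):
        params = {"conv1": {"bias": jnp.zeros((64,)),
                            "kernel": jnp.ones((3, 3, 32, 64))}}
    replicated = replicate(params)
    print(jax.tree_util.tree_map(lambda x: x.sharding.is_fully_replicated, replicated))
```

With a single device this reduces to a plain move; to place everything on one specific GPU, `jax.device_put(params, device)` as in the first example is enough.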
I will let the Flax team comment on the default behavior in your case.
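A plausible explanation of the default behavior, offered as an assumption and not confirmed by the Flax team in this thread: JAX runs an operation on the device of its committed inputs and falls back to the default device (`jax.devices()[0]`) when there are none. `nn.Conv`'s default kernel initializer consumes the `rng` that was committed to GPU 1, while its default bias initializer produces zeros without using the key, so the bias would land on GPU 0. That would also explain why `jax.default_device(device)` fixes it. The sketch below reproduces the rule with plain JAX arrays standing in for the two initializers; `target` is a hypothetical stand-in for `device`.

```python
import jax
import jax.numpy as jnp

# Use a non-default device if one exists (GPU 1 in the report); otherwise the only device.
target = jax.devices()[-1]

# device_put with an explicit device *commits* the key to `target`.
key = jax.device_put(jax.random.PRNGKey(0), target)

# Like a random kernel initializer: it consumes the committed key, so it runs on `target`.
kernel = jax.random.normal(key, (3, 3))

# Like a zeros bias initializer: no committed input, so it lands on the default device.
bias = jnp.zeros((3,))

print("kernel:", kernel.devices())
print("bias:  ", bias.devices())
```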