
bias and kernel params are put on different gpu devices #4116

Open
YunxiTang opened this issue Aug 6, 2024 · 1 comment

Comments

YunxiTang commented Aug 6, 2024

System information

  • OS Platform and Distribution: Linux Ubuntu 20.04
  • Flax, jax, jaxlib versions: flax -> 0.6.11, jax -> 0.4.13, jaxlib -> 0.4.13+cuda11.cudnn86
  • Python version: Python 3.9.19
  • GPU/TPU model and memory: GPU 4090 with 24 GB
  • CUDA version: CUDA 11.8

Problem you have encountered:

When I try to initialize a Flax model on a specific GPU device (for example, GPU 1), the bias and kernel params end up on different GPU devices.

What you expected to happen:

The bias and kernel params should be placed on the same GPU device.

Steps to reproduce:

  import jax
  import jax.numpy as jnp
  from jax import tree_util
  from flax import linen as nn

  device = jax.devices("gpu")[1]

  class MyModel(nn.Module):
      @nn.compact
      def __call__(self, x):
          x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
          x = nn.relu(x)
          return x

  rng = jax.random.PRNGKey(0)
  rng = jax.device_put(rng, device)
  dummy_input = jax.device_put(jnp.ones((5, 64, 64, 32)), device) 

  model = MyModel()  
  model_params = model.init({'params': rng}, dummy_input)
  # model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
  print(tree_util.tree_map(lambda x: x.device(), model_params))

The output is:

FrozenDict({
    params: {
        conv1: {
            bias: gpu(id=0),
            kernel: gpu(id=1),
        },
    },
})
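
For reference, uncommenting the tree_util.tree_map line in the repro gives one workaround (a minimal sketch, reusing model_params, device, and tree_util from the snippet above): explicitly moving every leaf with jax.device_put yields a consistent placement, although I would expect init to do this on its own.

  # Workaround sketch: move every param leaf to the target device.
  model_params = tree_util.tree_map(lambda x: jax.device_put(x, device), model_params)
  # Every leaf now reports the same device, e.g. gpu(id=1).
  print(tree_util.tree_map(lambda x: x.device(), model_params))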

Thanks!

@MasterSkepticista

Hi @YunxiTang, I am able to reproduce this issue.

In practice, I have seen Flax models initialized on CPU and then migrated/replicated to devices. Two examples (a self-contained sketch of the second follows the list):

  1. Migrating params post-initialization to GPU.
    # Optional: Init on `cpu`.
    model_params = jax.jit(model.init, backend="cpu")({'params': rng}, dummy_input)
    model_params = jax.device_put(model_params, device)
    jax.tree.map(lambda p: p.device, model_params)
    # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
  2. Using jax.default_device scope.
    with jax.default_device(device):
        model_params = model.init({'params': rng}, dummy_input)
        print(tree_util.tree_map(lambda x: (x.device), model_params))
        # {'params': {'conv1': {'bias': CudaDevice(id=1), 'kernel': CudaDevice(id=1)}}}
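
Put together, a minimal self-contained sketch of option 2 (assuming a machine with at least two GPUs and the same MyModel as in the repro; whether .device is a method or a property depends on the jax version):

  import jax
  import jax.numpy as jnp
  from flax import linen as nn

  class MyModel(nn.Module):
      @nn.compact
      def __call__(self, x):
          x = nn.Conv(64, (3, 3), 1, name='conv1')(x)
          return nn.relu(x)

  device = jax.devices("gpu")[1]

  # Arrays created inside this scope default to `device`,
  # so kernel and bias should land on the same GPU.
  with jax.default_device(device):
      rng = jax.random.PRNGKey(0)
      dummy_input = jnp.ones((5, 64, 64, 32))
      model_params = MyModel().init({'params': rng}, dummy_input)

  print(jax.tree_util.tree_map(lambda p: p.device, model_params))
  # Expected: every leaf on the same device, e.g. CudaDevice(id=1).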

I will let the Flax team comment on the default behavior in your case.
